Example #1
def load_model(model_prefix, word_vocab, batch_size):
    FLAGS = load_namespace(model_prefix + ".config.json")
    label_vocab = Vocab(model_prefix + ".label_vocab", fileformat='txt2')
    num_classes = label_vocab.size()
    best_path = model_prefix + ".best.model"
    with tf.Graph().as_default():
        initializer = tf.contrib.layers.xavier_initializer()
        with tf.variable_scope("Model", reuse=False, initializer=initializer):
            valid_graph = HanModelGraph(num_classes=num_classes, word_vocab=word_vocab,
                                        dropout_rate=FLAGS.dropout_rate, learning_rate=FLAGS.learning_rate,
                                        lambda_l2=FLAGS.lambda_l2,
                                        context_lstm_dim=FLAGS.context_lstm_dim,
                                        is_training=False, batch_size=batch_size)
        vars_ = {}
        print ("ValidGraph Build")
        for var in tf.global_variables():
            if "word_embedding" in var.name: continue
            if not var.name.startswith("Model"): continue
            vars_[var.name.split(":")[0]] = var
        saver = tf.train.Saver(vars_)

        config = tf.ConfigProto(intra_op_parallelism_threads=0,
                                inter_op_parallelism_threads=0,
                                allow_soft_placement=True)

        sess = tf.Session(config=config)
        sess.run(tf.global_variables_initializer())
        saver.restore(sess, best_path)
        return valid_graph, sess, label_vocab, FLAGS
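A minimal usage sketch for the function above; the word-vector file, the model prefix, and the downstream prediction step are placeholders for illustration, not part of the original snippet:

# Hypothetical driver -- paths below are assumptions.
word_vocab = Vocab("wordvec.txt", fileformat='txt3')
valid_graph, sess, label_vocab, FLAGS = load_model("logs/Han.demo", word_vocab, batch_size=1)
# valid_graph and sess can then be handed to whatever evaluate()/predict() helper
# the project uses; label_vocab maps predicted label ids back to label strings.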
Example #2
    wo_maxpool_match = False
    if hasattr(FLAGS, 'wo_maxpool_match'):
        wo_maxpool_match = FLAGS.wo_maxpool_match

    wo_attentive_match = False
    if hasattr(FLAGS, 'wo_attentive_match'):
        wo_attentive_match = FLAGS.wo_attentive_match

    wo_max_attentive_match = False
    if hasattr(FLAGS, 'wo_max_attentive_match'):
        wo_max_attentive_match = FLAGS.wo_max_attentive_match

    # load vocabs
    print('Loading vocabs.')
    word_vocab = Vocab(word_vec_path, fileformat='txt3')
    label_vocab = Vocab(model_prefix + ".label_vocab", fileformat='txt2')
    print('word_vocab: {}'.format(word_vocab.word_vecs.shape))
    print('label_vocab: {}'.format(label_vocab.word_vecs.shape))
    num_classes = label_vocab.size()

    POS_vocab = None
    NER_vocab = None
    char_vocab = None
    if with_POS:
        POS_vocab = Vocab(model_prefix + ".POS_vocab", fileformat='txt2')
    if with_NER:
        NER_vocab = Vocab(model_prefix + ".NER_vocab", fileformat='txt2')
    char_vocab = Vocab(model_prefix + ".char_vocab", fileformat='txt2')
    print('char_vocab: {}'.format(char_vocab.word_vecs.shape))
Example #3
    parser.add_argument('--model_prefix', type=str, required=True, help='Prefix to the models.')
    parser.add_argument('--in_path', type=str, default='../data_quora/dev.tsv', help='the path to the test file.')
    parser.add_argument('--out_path', type=str, required=True, help='The path to the output file.')
    parser.add_argument('--word_vec_path', type=str, default='../data_quora/wordvec.txt', help='word embedding file for the input file.')

    args, unparsed = parser.parse_known_args()
    
    # load the configuration file
    print('Loading configurations.')
    options = namespace_utils.load_namespace(args.model_prefix + ".config.json")

    if args.word_vec_path is None: args.word_vec_path = options.word_vec_path

    # load vocabs
    print('Loading vocabs.')
    word_vocab = Vocab(args.word_vec_path, fileformat='txt3')
    label_vocab = Vocab(args.model_prefix + ".label_vocab", fileformat='txt2')
    print('word_vocab: {}'.format(word_vocab.word_vecs.shape))
    print('label_vocab: {}'.format(label_vocab.word_vecs.shape))
    num_classes = label_vocab.size()

    char_vocab = None
    if options.with_char:
        char_vocab = Vocab(args.model_prefix + ".char_vocab", fileformat='txt2')
        print('char_vocab: {}'.format(char_vocab.word_vecs.shape))
    
    print('Build SentenceMatchDataStream ... ')
    testDataStream = SentenceMatchDataStream(args.in_path, word_vocab=word_vocab, char_vocab=char_vocab,
                                            label_vocab=label_vocab,
                                            isShuffle=False, isLoop=True, isSort=True, options=options)
    print('Number of instances in devDataStream: {}'.format(testDataStream.get_num_instance()))
Example #4
def main_cv(FLAGS):
    # np.random.seed(FLAGS.seed)

    for fold in range(FLAGS.cv_folds):
        print("Start training fold " + str(fold))
        train_path = FLAGS.cv_train_path + str(fold) + '.tsv'
        train_feat_path = FLAGS.cv_train_feat_path + str(fold) + '.tsv'
        dev_path = FLAGS.cv_dev_path + str(fold) + '.tsv'
        dev_feat_path = FLAGS.cv_dev_feat_path + str(fold) + '.tsv'
        word_vec_path = FLAGS.word_vec_path
        char_vec_path = FLAGS.char_vec_path
        log_dir = FLAGS.model_dir + '/cv_fold_' + str(fold)
        if not os.path.exists(log_dir):
            os.makedirs(log_dir)

        path_prefix = log_dir + "/SentenceMatch.{}".format(FLAGS.suffix)

        namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")

        # build vocabs
        word_vocab = Vocab(word_vec_path, fileformat='txt3')
        char_vocab = None

        best_path = path_prefix + '.best.model'
        char_path = path_prefix + ".char_vocab"
        label_path = path_prefix + ".label_vocab"
        has_pre_trained_model = False

        if os.path.exists(best_path + ".index"):
            has_pre_trained_model = True
            print('Loading vocabs from a pre-trained model ...')
            label_vocab = Vocab(label_path, fileformat='txt2')
            if FLAGS.with_char:
                char_vocab = Vocab(char_path, fileformat='txt2')
        else:
            print('Collecting words, chars and labels ...')
            (all_words, all_chars, all_labels, all_POSs,
             all_NERs) = collect_vocabs(train_path)
            print('Number of words: {}'.format(len(all_words)))
            label_vocab = Vocab(fileformat='voc', voc=all_labels, dim=2)
            label_vocab.dump_to_txt2(label_path)

            if FLAGS.with_char:
                print('Number of chars: {}'.format(len(all_chars)))
                if char_vec_path == "":
                    char_vocab = Vocab(fileformat='voc',
                                       voc=all_chars,
                                       dim=FLAGS.char_emb_dim)
                else:
                    char_vocab = Vocab(char_vec_path, fileformat='txt3')
                char_vocab.dump_to_txt2(char_path)

        print('word_vocab shape is {}'.format(word_vocab.word_vecs.shape))
        if FLAGS.with_char:
            print('char_vocab shape is {}'.format(char_vocab.word_vecs.shape))
        num_classes = label_vocab.size()
        print("Number of labels: {}".format(num_classes))
        sys.stdout.flush()

        print('Build SentenceMatchDataStream ... ')
        trainDataStream = SentenceMatchDataStream(train_path,
                                                  train_feat_path,
                                                  word_vocab=word_vocab,
                                                  char_vocab=char_vocab,
                                                  label_vocab=label_vocab,
                                                  isShuffle=True,
                                                  isLoop=True,
                                                  isSort=True,
                                                  options=FLAGS)
        print('Number of instances in trainDataStream: {}'.format(
            trainDataStream.get_num_instance()))
        print('Number of batches in trainDataStream: {}'.format(
            trainDataStream.get_num_batch()))
        sys.stdout.flush()

        devDataStream = SentenceMatchDataStream(dev_path,
                                                dev_feat_path,
                                                word_vocab=word_vocab,
                                                char_vocab=char_vocab,
                                                label_vocab=label_vocab,
                                                isShuffle=False,
                                                isLoop=True,
                                                isSort=True,
                                                options=FLAGS)
        print('Number of instances in devDataStream: {}'.format(
            devDataStream.get_num_instance()))
        print('Number of batches in devDataStream: {}'.format(
            devDataStream.get_num_batch()))
        sys.stdout.flush()

        init_scale = 0.01
        with tf.Graph().as_default():
            # tf.set_random_seed(FLAGS.seed)

            initializer = tf.random_uniform_initializer(
                -init_scale, init_scale)
            global_step = tf.train.get_or_create_global_step()
            with tf.variable_scope("Model",
                                   reuse=None,
                                   initializer=initializer):
                train_graph = SentenceMatchModelGraph(num_classes,
                                                      word_vocab=word_vocab,
                                                      char_vocab=char_vocab,
                                                      is_training=True,
                                                      options=FLAGS,
                                                      global_step=global_step)

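            # reuse=True: the validation graph below shares every weight with the
            # training graph defined in the same "Model" scope above.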
            with tf.variable_scope("Model",
                                   reuse=True,
                                   initializer=initializer):
                valid_graph = SentenceMatchModelGraph(num_classes,
                                                      word_vocab=word_vocab,
                                                      char_vocab=char_vocab,
                                                      is_training=False,
                                                      options=FLAGS)

            initializer = tf.global_variables_initializer()
            initializer_local = tf.local_variables_initializer()
            vars_ = {}
            for var in tf.global_variables():
                if "word_embedding" in var.name: continue
                # if not var.name.startswith("Model"): continue
                vars_[var.name.split(":")[0]] = var
            saver = tf.train.Saver(vars_)

            sess = tf.Session()
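            # The pre-trained word/char vectors are fed in through the graph's embedding
            # placeholders while running the initializer; the word-embedding variable
            # itself is excluded from the Saver above, so it is not checkpointed.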
            sess.run(initializer,
                     feed_dict={
                         train_graph.w_embedding: word_vocab.word_vecs,
                         train_graph.c_embedding: char_vocab.word_vecs
                     })
            sess.run(initializer_local)
            if has_pre_trained_model:
                print("Restoring model from " + best_path)
                saver.restore(sess, best_path)
                print("DONE!")

            # training
            train(sess, saver, train_graph, valid_graph, trainDataStream,
                  devDataStream, FLAGS, best_path, label_vocab)
        print()
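A minimal sketch of driving main_cv from a saved configuration; the config path is a placeholder, and the namespace must already carry the cv_* paths, model_dir, suffix and model options referenced above:

# Hypothetical driver -- the config file name is an assumption.
FLAGS = namespace_utils.load_namespace("configs/cv_quora.config.json")
main_cv(FLAGS)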
Example #5
def main(_):
    print('Configurations:')
    print(FLAGS)

    log_dir = FLAGS.model_dir
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    path_prefix = log_dir + "/G2S.{}".format(FLAGS.suffix)
    log_file_path = path_prefix + ".log"
    print('Log file path: {}'.format(log_file_path))
    log_file = open(log_file_path, 'wt')
    log_file.write("{}\n".format(FLAGS))
    log_file.flush()

    # save configuration
    namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")

    print('Loading train set.')
    trainset, trn_node, trn_in_neigh, trn_out_neigh, trn_sent = G2S_data_stream.read_amr_file(
        FLAGS.train_path)
    print('Number of training samples: {}'.format(len(trainset)))

    print('Loading dev set.')
    devset, tst_node, tst_in_neigh, tst_out_neigh, tst_sent = G2S_data_stream.read_amr_file(
        FLAGS.test_path)
    print('Number of dev samples: {}'.format(len(devset)))

    if FLAGS.finetune_path != "":
        print('Loading finetune set.')
        ftset, ft_node, ft_in_neigh, ft_out_neigh, ft_sent = G2S_data_stream.read_amr_file(
            FLAGS.finetune_path)
        print('Number of finetune samples: {}'.format(len(ftset)))
    else:
        ftset, ft_node, ft_in_neigh, ft_out_neigh, ft_sent = (None, 0, 0, 0, 0)

    max_node = max(trn_node, tst_node, ft_node)
    max_in_neigh = max(trn_in_neigh, tst_in_neigh, ft_in_neigh)
    max_out_neigh = max(trn_out_neigh, tst_out_neigh, ft_out_neigh)
    max_sent = max(trn_sent, tst_sent, ft_sent)
    print('Max node number: {}, while max allowed is {}'.format(
        max_node, FLAGS.max_node_num))
    print('Max parent number: {}, truncated to {}'.format(
        max_in_neigh, FLAGS.max_in_neigh_num))
    print('Max children number: {}, truncated to {}'.format(
        max_out_neigh, FLAGS.max_out_neigh_num))
    print('Max answer length: {}, truncated to {}'.format(
        max_sent, FLAGS.max_answer_len))

    word_vocab = None
    char_vocab = None
    edgelabel_vocab = None
    has_pretrained_model = False
    best_path = path_prefix + ".best.model"
    if os.path.exists(best_path + ".index"):
        has_pretrained_model = True
        print('!!Existing pretrained model. Loading vocabs.')
        word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2')
        print('word_vocab: {}'.format(word_vocab.word_vecs.shape))
        char_vocab = None
        if FLAGS.with_char:
            char_vocab = Vocab(path_prefix + ".char_vocab", fileformat='txt2')
            print('char_vocab: {}'.format(char_vocab.word_vecs.shape))
        edgelabel_vocab = Vocab(path_prefix + ".edgelabel_vocab",
                                fileformat='txt2')
    else:
        print('Collecting vocabs.')
        (allWords, allChars,
         allEdgelabels) = G2S_data_stream.collect_vocabs(trainset)
        print('Number of words: {}'.format(len(allWords)))
        print('Number of allChars: {}'.format(len(allChars)))
        print('Number of allEdgelabels: {}'.format(len(allEdgelabels)))

        word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2')
        char_vocab = None
        if FLAGS.with_char:
            char_vocab = Vocab(voc=allChars,
                               dim=FLAGS.char_dim,
                               fileformat='build')
            char_vocab.dump_to_txt2(path_prefix + ".char_vocab")
        edgelabel_vocab = Vocab(voc=allEdgelabels,
                                dim=FLAGS.edgelabel_dim,
                                fileformat='build')
        edgelabel_vocab.dump_to_txt2(path_prefix + ".edgelabel_vocab")

    print('word vocab size {}'.format(word_vocab.vocab_size))
    sys.stdout.flush()

    print('Build DataStream ... ')
    trainDataStream = G2S_data_stream.G2SDataStream(trainset,
                                                    word_vocab,
                                                    char_vocab,
                                                    edgelabel_vocab,
                                                    options=FLAGS,
                                                    isShuffle=True,
                                                    isLoop=True,
                                                    isSort=True)

    devDataStream = G2S_data_stream.G2SDataStream(devset,
                                                  word_vocab,
                                                  char_vocab,
                                                  edgelabel_vocab,
                                                  options=FLAGS,
                                                  isShuffle=False,
                                                  isLoop=False,
                                                  isSort=True)
    print('Number of instances in trainDataStream: {}'.format(
        trainDataStream.get_num_instance()))
    print('Number of instances in devDataStream: {}'.format(
        devDataStream.get_num_instance()))
    print('Number of batches in trainDataStream: {}'.format(
        trainDataStream.get_num_batch()))
    print('Number of batches in devDataStream: {}'.format(
        devDataStream.get_num_batch()))
    if ftset is not None:
        ftDataStream = G2S_data_stream.G2SDataStream(ftset,
                                                     word_vocab,
                                                     char_vocab,
                                                     edgelabel_vocab,
                                                     options=FLAGS,
                                                     isShuffle=True,
                                                     isLoop=True,
                                                     isSort=True)
        print('Number of instances in ftDataStream: {}'.format(
            ftDataStream.get_num_instance()))
        print('Number of batches in ftDataStream: {}'.format(
            ftDataStream.get_num_batch()))

    sys.stdout.flush()

    # initialize the best bleu and accu scores for current training session
    best_accu = FLAGS.best_accu if 'best_accu' in FLAGS.__dict__ else 0.0
    best_bleu = FLAGS.best_bleu if 'best_bleu' in FLAGS.__dict__ else 0.0
    if best_accu > 0.0:
        print('With initial dev accuracy {}'.format(best_accu))
    if best_bleu > 0.0:
        print('With initial dev BLEU score {}'.format(best_bleu))

    init_scale = 0.01
    with tf.Graph().as_default():
        initializer = tf.random_uniform_initializer(-init_scale, init_scale)
        with tf.name_scope("Train"):
            with tf.variable_scope("Model",
                                   reuse=None,
                                   initializer=initializer):
                train_graph = ModelGraph(word_vocab=word_vocab,
                                         Edgelabel_vocab=edgelabel_vocab,
                                         char_vocab=char_vocab,
                                         options=FLAGS,
                                         mode=FLAGS.mode)

        assert FLAGS.mode in ('ce_train', 'rl_train', 'transformer')
        valid_mode = 'evaluate' if FLAGS.mode in (
            'ce_train', 'transformer') else 'evaluate_bleu'

        with tf.name_scope("Valid"):
            with tf.variable_scope("Model",
                                   reuse=True,
                                   initializer=initializer):
                valid_graph = ModelGraph(word_vocab=word_vocab,
                                         Edgelabel_vocab=edgelabel_vocab,
                                         char_vocab=char_vocab,
                                         options=FLAGS,
                                         mode=valid_mode)

        initializer = tf.global_variables_initializer()

        vars_ = {}
        for var in tf.all_variables():
            if FLAGS.fix_word_vec and "word_embedding" in var.name: continue
            if not var.name.startswith("Model"): continue
            vars_[var.name.split(":")[0]] = var
        saver = tf.train.Saver(vars_)

        sess = tf.Session()
        sess.run(initializer)
        if has_pretrained_model:
            print("Restoring model from " + best_path)
            saver.restore(sess, best_path)
            print("DONE!")

            if FLAGS.mode == 'rl_train' and abs(best_bleu) < 0.00001:
                print("Getting BLEU score for the model")
                sys.stdout.flush()
                best_bleu = evaluate(sess,
                                     valid_graph,
                                     devDataStream,
                                     options=FLAGS)['dev_bleu']
                FLAGS.best_bleu = best_bleu
                namespace_utils.save_namespace(FLAGS,
                                               path_prefix + ".config.json")
                print('BLEU = %.4f' % best_bleu)
                sys.stdout.flush()
                log_file.write('BLEU = %.4f\n' % best_bleu)
            if FLAGS.mode in ('ce_train', 'rl_train',
                              'transformer') and abs(best_accu) < 0.00001:
                print("Getting ACCU score for the model")
                best_accu = evaluate(sess,
                                     valid_graph,
                                     devDataStream,
                                     options=FLAGS)['dev_accu']
                FLAGS.best_accu = best_accu
                namespace_utils.save_namespace(FLAGS,
                                               path_prefix + ".config.json")
                print('ACCU = %.4f' % best_accu)
                log_file.write('ACCU = %.4f\n' % best_accu)

        print('Start the training loop.')
        train_size = trainDataStream.get_num_batch()
        max_steps = train_size * FLAGS.max_epochs
        total_loss = 0.0
        start_time = time.time()
        for step in range(max_steps):
            cur_batch = trainDataStream.nextBatch()
            if FLAGS.mode == 'rl_train':
                loss_value = train_graph.run_rl_training_subsample(
                    sess, cur_batch, FLAGS)
            elif FLAGS.mode in ('ce_train', 'rl_train', 'transformer'):
                loss_value = train_graph.run_ce_training(
                    sess, cur_batch, FLAGS)
            total_loss += loss_value

            if step % 100 == 0:
                print('{} '.format(step), end="")
                sys.stdout.flush()

            # Save a checkpoint and evaluate the model periodically.
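            # An epoch boundary is every get_num_batch() steps; very large training
            # sets (over 10000 batches) are additionally evaluated every 2000 steps.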
            if (step + 1) % trainDataStream.get_num_batch() == 0 or (step + 1) == max_steps or \
                    (trainDataStream.get_num_batch() > 10000 and (step+1)%2000 == 0):
                print()
                duration = time.time() - start_time
                print('Step %d: loss = %.2f (%.3f sec)' %
                      (step, total_loss, duration))
                log_file.write('Step %d: loss = %.2f (%.3f sec)\n' %
                               (step, total_loss, duration))
                log_file.flush()
                sys.stdout.flush()
                total_loss = 0.0

                if ftset is not None:
                    best_accu, best_bleu = fine_tune(sess, saver, FLAGS,
                                                     log_file, ftDataStream,
                                                     devDataStream,
                                                     train_graph, valid_graph,
                                                     path_prefix, best_accu,
                                                     best_bleu)
                else:
                    best_accu, best_bleu = validate_and_save(
                        sess, saver, FLAGS, log_file, devDataStream,
                        valid_graph, path_prefix, best_accu, best_bleu)
                start_time = time.time()

    log_file.close()
Example #6
    mode = args.mode

    print("CUDA_VISIBLE_DEVICES " + os.environ['CUDA_VISIBLE_DEVICES'])

    # load the configuration file
    print('Loading configurations from ' + model_prefix + ".config.json")
    FLAGS = namespace_utils.load_namespace(model_prefix + ".config.json")
    FLAGS = NP2P_trainer.enrich_options(FLAGS)
    if args.beam_size != -1:
        FLAGS.beam_size = args.beam_size

    # load vocabs
    print('Loading vocabs.')
    enc_word_vocab = dec_word_vocab = char_vocab = POS_vocab = NER_vocab = None
    if FLAGS.with_word:
        enc_word_vocab = Vocab(FLAGS.enc_word_vec_path, fileformat='txt2')
        print('enc_word_vocab: {}'.format(enc_word_vocab.word_vecs.shape))
        dec_word_vocab = Vocab(FLAGS.dec_word_vec_path, fileformat='txt2')
        print('dec_word_vocab: {}'.format(dec_word_vocab.word_vecs.shape))
    if FLAGS.with_char:
        char_vocab = Vocab(model_prefix + ".char_vocab", fileformat='txt2')
        print('char_vocab: {}'.format(char_vocab.word_vecs.shape))
    if FLAGS.with_POS:
        POS_vocab = Vocab(model_prefix + ".POS_vocab", fileformat='txt2')
        print('POS_vocab: {}'.format(POS_vocab.word_vecs.shape))
    if FLAGS.with_NER:
        NER_vocab = Vocab(model_prefix + ".NER_vocab", fileformat='txt2')
        print('NER_vocab: {}'.format(NER_vocab.word_vecs.shape))

    print('Loading test set.')
    if FLAGS.infile_format == 'fof':
Example #7
    model_prefix = args.model_prefix
    in_path = args.in_path
    cache_size = args.cache_size
    use_dep = args.decode

    print("CUDA_VISIBLE_DEVICES " + os.environ['CUDA_VISIBLE_DEVICES'])

    # load the configuration file
    print('Loading configurations from ' + model_prefix + ".config.json")
    FLAGS = namespace_utils.load_namespace(model_prefix + ".config.json")
    FLAGS = NP2P_trainer.enrich_options(FLAGS)

    # load vocabs
    print('Loading vocabs.')
    word_vocab = char_vocab = POS_vocab = NER_vocab = None
    word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2')
    print('word_vocab: {}'.format(word_vocab.word_vecs.shape))
    if FLAGS.with_char:
        char_vocab = Vocab(model_prefix + ".char_vocab", fileformat='txt2')
        print('char_vocab: {}'.format(char_vocab.word_vecs.shape))
    if FLAGS.with_POS:
        POS_vocab = Vocab(model_prefix + ".POS_vocab", fileformat='txt2')
        print('POS_vocab: {}'.format(POS_vocab.word_vecs.shape))
    action_vocab = Vocab(model_prefix + ".action_vocab", fileformat='txt2')
    print('action_vocab: {}'.format(action_vocab.word_vecs.shape))
    feat_vocab = Vocab(model_prefix + ".feat_vocab", fileformat='txt2')
    print('feat_vocab: {}'.format(feat_vocab.word_vecs.shape))

    print('Loading test set.')
    if use_dep:
        testset = NP2P_data_stream.read_Testset(in_path)
Example #8
def main(_):
    print('Configurations:')
    print(FLAGS)

    log_dir = FLAGS.model_dir
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    path_prefix = log_dir + "/G2S.{}".format(FLAGS.suffix)
    log_file_path = path_prefix + ".log"
    print('Log file path: {}'.format(log_file_path))
    log_file = open(log_file_path, 'wt')
    log_file.write("{}\n".format(FLAGS))
    log_file.flush()

    # save configuration
    namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")

    print('Loading train set.')
    if FLAGS.infile_format == 'fof':
        trainset, trn_node, trn_in_neigh, trn_out_neigh, trn_sent = G2S_data_stream.read_nary_from_fof(
            FLAGS.train_path, FLAGS)
    else:
        trainset, trn_node, trn_in_neigh, trn_out_neigh, trn_sent = G2S_data_stream.read_nary_file(
            FLAGS.train_path, FLAGS)

    random.shuffle(trainset)
    devset = trainset[:200]
    trainset = trainset[200:]

    print('Number of training samples: {}'.format(len(trainset)))
    print('Number of dev samples: {}'.format(len(devset)))

    max_node = trn_node
    max_in_neigh = trn_in_neigh
    max_out_neigh = trn_out_neigh
    max_sent = trn_sent
    print('Max node number: {}, while max allowed is {}'.format(
        max_node, FLAGS.max_node_num))
    print('Max parent number: {}, truncated to {}'.format(
        max_in_neigh, FLAGS.max_in_neigh_num))
    print('Max children number: {}, truncated to {}'.format(
        max_out_neigh, FLAGS.max_out_neigh_num))
    print('Max entity size: {}, truncated to {}'.format(
        max_sent, FLAGS.max_entity_size))

    word_vocab = None
    char_vocab = None
    edgelabel_vocab = None
    has_pretrained_model = False
    best_path = path_prefix + ".best.model"
    if os.path.exists(best_path + ".index"):
        has_pretrained_model = True
        print('!!Existing pretrained model. Loading vocabs.')
        word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2')
        print('word_vocab: {}'.format(word_vocab.word_vecs.shape))
        char_vocab = None
        if FLAGS.with_char:
            char_vocab = Vocab(path_prefix + ".char_vocab", fileformat='txt2')
            print('char_vocab: {}'.format(char_vocab.word_vecs.shape))
        edgelabel_vocab = Vocab(path_prefix + ".edgelabel_vocab",
                                fileformat='txt2')
    else:
        print('Collecting vocabs.')
        (allWords, allChars,
         allEdgelabels) = G2S_data_stream.collect_vocabs(trainset)
        print('Number of words: {}'.format(len(allWords)))
        print('Number of allChars: {}'.format(len(allChars)))
        print('Number of allEdgelabels: {}'.format(len(allEdgelabels)))

        word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2')
        char_vocab = None
        if FLAGS.with_char:
            char_vocab = Vocab(voc=allChars,
                               dim=FLAGS.char_dim,
                               fileformat='build')
            char_vocab.dump_to_txt2(path_prefix + ".char_vocab")
        edgelabel_vocab = Vocab(voc=allEdgelabels,
                                dim=FLAGS.edgelabel_dim,
                                fileformat='build')
        edgelabel_vocab.dump_to_txt2(path_prefix + ".edgelabel_vocab")

    print('word vocab size {}'.format(word_vocab.vocab_size))
    sys.stdout.flush()

    print('Build DataStream ... ')
    trainDataStream = G2S_data_stream.G2SDataStream(trainset,
                                                    word_vocab,
                                                    char_vocab,
                                                    edgelabel_vocab,
                                                    options=FLAGS,
                                                    isShuffle=True,
                                                    isLoop=True,
                                                    isSort=False)

    devDataStream = G2S_data_stream.G2SDataStream(devset,
                                                  word_vocab,
                                                  char_vocab,
                                                  edgelabel_vocab,
                                                  options=FLAGS,
                                                  isShuffle=False,
                                                  isLoop=False,
                                                  isSort=False)
    print('Number of instances in trainDataStream: {}'.format(
        trainDataStream.get_num_instance()))
    print('Number of instances in devDataStream: {}'.format(
        devDataStream.get_num_instance()))
    print('Number of batches in trainDataStream: {}'.format(
        trainDataStream.get_num_batch()))
    print('Number of batches in devDataStream: {}'.format(
        devDataStream.get_num_batch()))
    sys.stdout.flush()

    # initialize the best bleu and accu scores for current training session
    best_accu = FLAGS.best_accu if 'best_accu' in FLAGS.__dict__ else 0.0
    if best_accu > 0.0:
        print('With initial dev accuracy {}'.format(best_accu))

    init_scale = 0.01
    with tf.Graph().as_default():
        initializer = tf.random_uniform_initializer(-init_scale, init_scale)
        with tf.name_scope("Train"):
            with tf.variable_scope("Model",
                                   reuse=None,
                                   initializer=initializer):
                train_graph = ModelGraph(word_vocab=word_vocab,
                                         Edgelabel_vocab=edgelabel_vocab,
                                         char_vocab=char_vocab,
                                         options=FLAGS,
                                         mode='train')

        with tf.name_scope("Valid"):
            with tf.variable_scope("Model",
                                   reuse=True,
                                   initializer=initializer):
                valid_graph = ModelGraph(word_vocab=word_vocab,
                                         Edgelabel_vocab=edgelabel_vocab,
                                         char_vocab=char_vocab,
                                         options=FLAGS,
                                         mode='evaluate')

        initializer = tf.global_variables_initializer()

        vars_ = {}
        for var in tf.all_variables():
            if "word_embedding" in var.name: continue
            if not var.name.startswith("Model"): continue
            vars_[var.name.split(":")[0]] = var
        saver = tf.train.Saver(vars_)

        sess = tf.Session()
        sess.run(initializer)
        if has_pretrained_model:
            print("Restoring model from " + best_path)
            saver.restore(sess, best_path)
            print("DONE!")

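            # If the restored config carries no dev accuracy yet, score the model once
            # so the training loop below compares against a real baseline.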
            if abs(best_accu) < 0.00001:
                print("Getting ACCU score for the model")
                best_accu = evaluate(sess,
                                     valid_graph,
                                     devDataStream,
                                     options=FLAGS)['dev_accu']
                FLAGS.best_accu = best_accu
                namespace_utils.save_namespace(FLAGS,
                                               path_prefix + ".config.json")
                print('ACCU = %.4f' % best_accu)
                log_file.write('ACCU = %.4f\n' % best_accu)

        print('Start the training loop.')
        train_size = trainDataStream.get_num_batch()
        max_steps = train_size * FLAGS.max_epochs
        last_step = 0
        total_loss = 0.0
        start_time = time.time()
        for step in range(max_steps):
            cur_batch = trainDataStream.nextBatch()
            _, loss_value, _ = train_graph.execute(sess,
                                                   cur_batch,
                                                   FLAGS,
                                                   is_train=True)
            total_loss += loss_value

            if step % 100 == 0:
                print('{} '.format(step), end="")
                sys.stdout.flush()

            # Save a checkpoint and evaluate the model periodically.
            if (step + 1) % trainDataStream.get_num_batch() == 0 or (
                    step + 1) == max_steps:
                print()
                duration = time.time() - start_time
                print('Step %d: loss = %.2f (%.3f sec)' %
                      (step, total_loss / (step - last_step), duration))
                log_file.write('Step %d: loss = %.2f (%.3f sec)\n' %
                               (step, total_loss /
                                (step - last_step), duration))
                sys.stdout.flush()
                log_file.flush()
                last_step = step
                total_loss = 0.0

                # Evaluate against the validation set.
                start_time = time.time()
                print('Validation Data Eval:')
                res_dict = evaluate(sess,
                                    valid_graph,
                                    devDataStream,
                                    options=FLAGS,
                                    suffix=str(step))
                dev_loss = res_dict['dev_loss']
                dev_accu = res_dict['dev_accu']
                dev_right = int(res_dict['dev_right'])
                dev_total = int(res_dict['dev_total'])
                print('Dev loss = %.4f' % dev_loss)
                log_file.write('Dev loss = %.4f\n' % dev_loss)
                print('Dev accu = %.4f %d/%d' %
                      (dev_accu, dev_right, dev_total))
                log_file.write('Dev accu = %.4f %d/%d\n' %
                               (dev_accu, dev_right, dev_total))
                log_file.flush()
                if best_accu < dev_accu:
                    print('Saving weights, ACCU {} (prev_best) < {} (cur)'.
                          format(best_accu, dev_accu))
                    saver.save(sess, best_path)
                    best_accu = dev_accu
                    FLAGS.best_accu = dev_accu
                    namespace_utils.save_namespace(
                        FLAGS, path_prefix + ".config.json")
                    json.dump(res_dict['data'], open(FLAGS.output_path, 'w'))
                duration = time.time() - start_time
                print('Duration %.3f sec' % (duration))
                sys.stdout.flush()

                log_file.write('Duration %.3f sec\n' % (duration))
                log_file.flush()

    log_file.close()
Example #9
    parser.add_argument('--word_vec_path',
                        type=str,
                        help='word embedding file for the input file.')

    args, unparsed = parser.parse_known_args()

    # load the configuration file
    print('Loading configurations.')
    options = namespace_utils.load_namespace(args.model_prefix +
                                             "ESIM.xnli.config.json")

    if args.word_vec_path is None: args.word_vec_path = options.word_vec_path

    # load vocabs
    print('Loading vocabs.')
    word_vocab = Vocab(args.word_vec_path, fileformat='txt3')
    print('word_vocab: {}'.format(word_vocab.word_vecs.shape))

    print('Build DataStream ... ')
    testDataStream = DataStream(args.in_path,
                                word_vocab=word_vocab,
                                label_vocab=None,
                                isShuffle=False,
                                isLoop=True,
                                isSort=True,
                                options=options)
    print('Number of instances in devDataStream: {}'.format(
        testDataStream.get_num_instance()))
    print('Number of batches in devDataStream: {}'.format(
        testDataStream.get_num_batch()))
    sys.stdout.flush()
Example #10
def main(_):
    print('Configurations:')
    print(FLAGS)

    train_path = FLAGS.train_path
    dev_path = FLAGS.dev_path
    test_path = FLAGS.test_path
    word_vec_path = FLAGS.word_vec_path
    log_dir = FLAGS.model_dir
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    
    path_prefix = log_dir + "/SentenceMatch.{}".format(FLAGS.suffix)

    namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")

    # build vocabs
    word_vocab = Vocab(word_vec_path, fileformat='txt3')
    best_path = path_prefix + '.best.model'
    char_path = path_prefix + ".char_vocab"
    label_path = path_prefix + ".label_vocab"
    POS_path = path_prefix + ".POS_vocab"
    NER_path = path_prefix + ".NER_vocab"
    has_pre_trained_model = False
    POS_vocab = None
    NER_vocab = None
    if os.path.exists(best_path):
        has_pre_trained_model = True
        label_vocab = Vocab(label_path, fileformat='txt2')
        char_vocab = Vocab(char_path, fileformat='txt2')
        if FLAGS.with_POS: POS_vocab = Vocab(POS_path, fileformat='txt2')
        if FLAGS.with_NER: NER_vocab = Vocab(NER_path, fileformat='txt2')
    else:
        print('Collect words, chars and labels ...')
        (all_words, all_chars, all_labels, all_POSs, all_NERs) = collect_vocabs(train_path, with_POS=FLAGS.with_POS, with_NER=FLAGS.with_NER)
        print('Number of words: {}'.format(len(all_words)))
        print('Number of labels: {}'.format(len(all_labels)))
        label_vocab = Vocab(fileformat='voc', voc=all_labels,dim=2)
        label_vocab.dump_to_txt2(label_path)

        print('Number of chars: {}'.format(len(all_chars)))
        char_vocab = Vocab(fileformat='voc', voc=all_chars,dim=FLAGS.char_emb_dim)
        char_vocab.dump_to_txt2(char_path)
        
        if FLAGS.with_POS:
            print('Number of POSs: {}'.format(len(all_POSs)))
            POS_vocab = Vocab(fileformat='voc', voc=all_POSs,dim=FLAGS.POS_dim)
            POS_vocab.dump_to_txt2(POS_path)
        if FLAGS.with_NER:
            print('Number of NERs: {}'.format(len(all_NERs)))
            NER_vocab = Vocab(fileformat='voc', voc=all_NERs,dim=FLAGS.NER_dim)
            NER_vocab.dump_to_txt2(NER_path)
            

    print('word_vocab shape is {}'.format(word_vocab.word_vecs.shape))
    print('tag_vocab shape is {}'.format(label_vocab.word_vecs.shape))
    num_classes = label_vocab.size()

    print('Build SentenceMatchDataStream ... ')
    trainDataStream = SentenceMatchDataStream(train_path, word_vocab=word_vocab, char_vocab=char_vocab, 
                                              POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab, 
                                              batch_size=FLAGS.batch_size, isShuffle=True, isLoop=True, isSort=True, 
                                              max_char_per_word=FLAGS.max_char_per_word, max_sent_length=FLAGS.max_sent_length)
                                    
    devDataStream = SentenceMatchDataStream(dev_path, word_vocab=word_vocab, char_vocab=char_vocab,
                                              POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab, 
                                              batch_size=FLAGS.batch_size, isShuffle=False, isLoop=True, isSort=True, 
                                              max_char_per_word=FLAGS.max_char_per_word, max_sent_length=FLAGS.max_sent_length)

    testDataStream = SentenceMatchDataStream(test_path, word_vocab=word_vocab, char_vocab=char_vocab, 
                                              POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab, 
                                              batch_size=FLAGS.batch_size, isShuffle=False, isLoop=True, isSort=True, 
                                              max_char_per_word=FLAGS.max_char_per_word, max_sent_length=FLAGS.max_sent_length)

    print('Number of instances in trainDataStream: {}'.format(trainDataStream.get_num_instance()))
    print('Number of instances in devDataStream: {}'.format(devDataStream.get_num_instance()))
    print('Number of instances in testDataStream: {}'.format(testDataStream.get_num_instance()))
    print('Number of batches in trainDataStream: {}'.format(trainDataStream.get_num_batch()))
    print('Number of batches in devDataStream: {}'.format(devDataStream.get_num_batch()))
    print('Number of batches in testDataStream: {}'.format(testDataStream.get_num_batch()))
    
    sys.stdout.flush()
    if FLAGS.wo_char: char_vocab = None

    best_accuracy = 0.0
    init_scale = 0.01
    with tf.Graph().as_default():
        initializer = tf.random_uniform_initializer(-init_scale, init_scale)
#         with tf.name_scope("Train"):
        with tf.variable_scope("Model", reuse=None, initializer=initializer):
            train_graph = SentenceMatchModelGraph(num_classes, word_vocab=word_vocab, char_vocab=char_vocab,POS_vocab=POS_vocab, NER_vocab=NER_vocab, 
                 dropout_rate=FLAGS.dropout_rate, learning_rate=FLAGS.learning_rate, optimize_type=FLAGS.optimize_type,
                 lambda_l2=FLAGS.lambda_l2, char_lstm_dim=FLAGS.char_lstm_dim, context_lstm_dim=FLAGS.context_lstm_dim, 
                 aggregation_lstm_dim=FLAGS.aggregation_lstm_dim, is_training=True, MP_dim=FLAGS.MP_dim, 
                 context_layer_num=FLAGS.context_layer_num, aggregation_layer_num=FLAGS.aggregation_layer_num, 
                 fix_word_vec=FLAGS.fix_word_vec,with_filter_layer=FLAGS.with_filter_layer, with_highway=FLAGS.with_highway,
                 word_level_MP_dim=FLAGS.word_level_MP_dim,
                 with_match_highway=FLAGS.with_match_highway, with_aggregation_highway=FLAGS.with_aggregation_highway,
                 highway_layer_num=FLAGS.highway_layer_num,with_lex_decomposition=FLAGS.with_lex_decomposition, 
                 lex_decompsition_dim=FLAGS.lex_decompsition_dim,
                 with_left_match=(not FLAGS.wo_left_match), with_right_match=(not FLAGS.wo_right_match),
                 with_full_match=(not FLAGS.wo_full_match), with_maxpool_match=(not FLAGS.wo_maxpool_match), 
                 with_attentive_match=(not FLAGS.wo_attentive_match), with_max_attentive_match=(not FLAGS.wo_max_attentive_match))
            tf.summary.scalar("Training Loss", train_graph.get_loss()) # Add a scalar summary for the snapshot loss.
        
#         with tf.name_scope("Valid"):
        with tf.variable_scope("Model", reuse=True, initializer=initializer):
            valid_graph = SentenceMatchModelGraph(num_classes, word_vocab=word_vocab, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab, 
                 dropout_rate=FLAGS.dropout_rate, learning_rate=FLAGS.learning_rate, optimize_type=FLAGS.optimize_type,
                 lambda_l2=FLAGS.lambda_l2, char_lstm_dim=FLAGS.char_lstm_dim, context_lstm_dim=FLAGS.context_lstm_dim, 
                 aggregation_lstm_dim=FLAGS.aggregation_lstm_dim, is_training=False, MP_dim=FLAGS.MP_dim, 
                 context_layer_num=FLAGS.context_layer_num, aggregation_layer_num=FLAGS.aggregation_layer_num, 
                 fix_word_vec=FLAGS.fix_word_vec,with_filter_layer=FLAGS.with_filter_layer, with_highway=FLAGS.with_highway,
                 word_level_MP_dim=FLAGS.word_level_MP_dim,
                 with_match_highway=FLAGS.with_match_highway, with_aggregation_highway=FLAGS.with_aggregation_highway,
                 highway_layer_num=FLAGS.highway_layer_num, with_lex_decomposition=FLAGS.with_lex_decomposition, 
                 lex_decompsition_dim=FLAGS.lex_decompsition_dim,
                 with_left_match=(not FLAGS.wo_left_match), with_right_match=(not FLAGS.wo_right_match),
                 with_full_match=(not FLAGS.wo_full_match), with_maxpool_match=(not FLAGS.wo_maxpool_match), 
                 with_attentive_match=(not FLAGS.wo_attentive_match), with_max_attentive_match=(not FLAGS.wo_max_attentive_match))

                
        initializer = tf.global_variables_initializer()
        vars_ = {}
        for var in tf.all_variables():
            if "word_embedding" in var.name: continue
#             if not var.name.startswith("Model"): continue
            vars_[var.name.split(":")[0]] = var
        saver = tf.train.Saver(vars_)
         
        sess = tf.Session()
        sess.run(initializer)
        if has_pre_trained_model:
            print("Restoring model from " + best_path)
            saver.restore(sess, best_path)
            print("DONE!")

        print('Start the training loop.')
        train_size = trainDataStream.get_num_batch()
        max_steps = train_size * FLAGS.max_epochs
        total_loss = 0.0
        start_time = time.time()
        for step in range(max_steps):
            # read data
            cur_batch = trainDataStream.nextBatch()
            (label_batch, sent1_batch, sent2_batch, label_id_batch, word_idx_1_batch, word_idx_2_batch, 
                                 char_matrix_idx_1_batch, char_matrix_idx_2_batch, sent1_length_batch, sent2_length_batch, 
                                 sent1_char_length_batch, sent2_char_length_batch,
                                 POS_idx_1_batch, POS_idx_2_batch, NER_idx_1_batch, NER_idx_2_batch) = cur_batch
            feed_dict = {
                         train_graph.get_truth(): label_id_batch, 
                         train_graph.get_question_lengths(): sent1_length_batch, 
                         train_graph.get_passage_lengths(): sent2_length_batch, 
                         train_graph.get_in_question_words(): word_idx_1_batch, 
                         train_graph.get_in_passage_words(): word_idx_2_batch, 
#                          train_graph.get_question_char_lengths(): sent1_char_length_batch, 
#                          train_graph.get_passage_char_lengths(): sent2_char_length_batch, 
#                          train_graph.get_in_question_chars(): char_matrix_idx_1_batch, 
#                          train_graph.get_in_passage_chars(): char_matrix_idx_2_batch, 
                         }
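            # Character, POS and NER features are only fed when the corresponding
            # vocabulary was actually built.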
            if char_vocab is not None:
                feed_dict[train_graph.get_question_char_lengths()] = sent1_char_length_batch
                feed_dict[train_graph.get_passage_char_lengths()] = sent2_char_length_batch
                feed_dict[train_graph.get_in_question_chars()] = char_matrix_idx_1_batch
                feed_dict[train_graph.get_in_passage_chars()] = char_matrix_idx_2_batch

            if POS_vocab is not None:
                feed_dict[train_graph.get_in_question_poss()] = POS_idx_1_batch
                feed_dict[train_graph.get_in_passage_poss()] = POS_idx_2_batch

            if NER_vocab is not None:
                feed_dict[train_graph.get_in_question_ners()] = NER_idx_1_batch
                feed_dict[train_graph.get_in_passage_ners()] = NER_idx_2_batch

            _, loss_value = sess.run([train_graph.get_train_op(), train_graph.get_loss()], feed_dict=feed_dict)
            total_loss += loss_value
            
            if step % 100==0: 
                print('{} '.format(step), end="")
                sys.stdout.flush()

            # Save a checkpoint and evaluate the model periodically.
            if (step + 1) % trainDataStream.get_num_batch() == 0 or (step + 1) == max_steps:
                print()
                # Print status to stdout.
                duration = time.time() - start_time
                start_time = time.time()
                print('Step %d: loss = %.2f (%.3f sec)' % (step, total_loss, duration))
                total_loss = 0.0

                # Evaluate against the validation set.
                print('Validation Data Eval:')
                accuracy = evaluate(devDataStream, valid_graph, sess,char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab)
                print("Current accuracy is %.2f" % accuracy)
                if accuracy>best_accuracy:
                    best_accuracy = accuracy
                    saver.save(sess, best_path)

    print("Best accuracy on dev set is %.2f" % best_accuracy)
    # decoding
    print('Decoding on the test set:')
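    # Rebuild the model in a fresh graph and session with is_training=False,
    # restore the best checkpoint, and score the held-out test set.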
    init_scale = 0.01
    with tf.Graph().as_default():
        initializer = tf.random_uniform_initializer(-init_scale, init_scale)
        with tf.variable_scope("Model", reuse=False, initializer=initializer):
            valid_graph = SentenceMatchModelGraph(num_classes, word_vocab=word_vocab, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab, 
                 dropout_rate=FLAGS.dropout_rate, learning_rate=FLAGS.learning_rate, optimize_type=FLAGS.optimize_type,
                 lambda_l2=FLAGS.lambda_l2, char_lstm_dim=FLAGS.char_lstm_dim, context_lstm_dim=FLAGS.context_lstm_dim, 
                 aggregation_lstm_dim=FLAGS.aggregation_lstm_dim, is_training=False, MP_dim=FLAGS.MP_dim, 
                 context_layer_num=FLAGS.context_layer_num, aggregation_layer_num=FLAGS.aggregation_layer_num, 
                 fix_word_vec=FLAGS.fix_word_vec,with_filter_layer=FLAGS.with_filter_layer, with_highway=FLAGS.with_highway,
                 word_level_MP_dim=FLAGS.word_level_MP_dim,
                 with_match_highway=FLAGS.with_match_highway, with_aggregation_highway=FLAGS.with_aggregation_highway,
                 highway_layer_num=FLAGS.highway_layer_num, with_lex_decomposition=FLAGS.with_lex_decomposition, 
                 lex_decompsition_dim=FLAGS.lex_decompsition_dim,
                 with_left_match=(not FLAGS.wo_left_match), with_right_match=(not FLAGS.wo_right_match),
                 with_full_match=(not FLAGS.wo_full_match), with_maxpool_match=(not FLAGS.wo_maxpool_match), 
                 with_attentive_match=(not FLAGS.wo_attentive_match), with_max_attentive_match=(not FLAGS.wo_max_attentive_match))
        vars_ = {}
        for var in tf.all_variables():
            if "word_embedding" in var.name: continue
            if not var.name.startswith("Model"): continue
            vars_[var.name.split(":")[0]] = var
        saver = tf.train.Saver(vars_)
                
        sess = tf.Session()
        sess.run(tf.global_variables_initializer())
        step = 0
        saver.restore(sess, best_path)

        accuracy = evaluate(testDataStream, valid_graph, sess,char_vocab=char_vocab,POS_vocab=POS_vocab, NER_vocab=NER_vocab)
        print("Accuracy for test set is %.2f" % accuracy)
Example #11
def main(_):
    print('Configurations:')
    print(FLAGS)
    train_path = FLAGS.train_path
    dev_path = FLAGS.dev_path
    test_path = FLAGS.test_path
    word_vec_path = FLAGS.word_vec_path
    log_dir = FLAGS.model_dir
    result_dir = '../result'
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    if not os.path.exists(result_dir):
        os.makedirs(result_dir)

    path_prefix = log_dir + "/Han.{}".format(FLAGS.suffix)
    save_namespace(FLAGS, path_prefix + ".config.json")
    word_vocab = Vocab(word_vec_path, fileformat='txt3')
    best_path = path_prefix + '.best.model'
    label_path = path_prefix + ".label_vocab"

    print('Collect words and labels ...')
    (all_words, all_labels) = collect_vocabs(train_path)
    print('Number of words: {}'.format(len(all_words)))
    print('Number of labels: {}'.format(len(all_labels)))
    label_vocab = Vocab(fileformat='voc', voc=all_labels, dim=2)
    label_vocab.dump_to_txt2(label_path)

    print('word_vocab shape is {}'.format(word_vocab.word_vecs.shape))
    print('tag_vocab shape is {}'.format(label_vocab.word_vecs.shape))
    num_classes = label_vocab.size()

    print('Build HanDataStream ... ')
    trainDataStream = HanDataStream(inpath=train_path, word_vocab=word_vocab,
                                              label_vocab=label_vocab,
                                              isShuffle=True, isLoop=True,
                                              max_sent_length=FLAGS.max_sent_length)
    devDataStream = HanDataStream(inpath=dev_path, word_vocab=word_vocab,
                                              label_vocab=label_vocab,
                                              isShuffle=False, isLoop=True,
                                              max_sent_length=FLAGS.max_sent_length)
    testDataStream = HanDataStream(inpath=test_path, word_vocab=word_vocab,
                                              label_vocab=label_vocab,
                                              isShuffle=False, isLoop=True,
                                              max_sent_length=FLAGS.max_sent_length)

    print('Number of instances in trainDataStream: {}'.format(trainDataStream.get_num_instance()))
    print('Number of instances in devDataStream: {}'.format(devDataStream.get_num_instance()))
    print('Number of instances in testDataStream: {}'.format(testDataStream.get_num_instance()))

    with tf.Graph().as_default():
        initializer = tf.contrib.layers.xavier_initializer()
        with tf.variable_scope("Model", reuse=None, initializer=initializer):
            train_graph = HanModelGraph(num_classes=num_classes, word_vocab=word_vocab,
                                                  dropout_rate=FLAGS.dropout_rate, learning_rate=FLAGS.learning_rate,
                                                  lambda_l2=FLAGS.lambda_l2,
                                                  context_lstm_dim=FLAGS.context_lstm_dim,
                                                  is_training=True, batch_size = FLAGS.batch_size)
            tf.summary.scalar("Training Loss", train_graph.loss)  # Add a scalar summary for the snapshot loss.
        print("Train Graph Build")
        with tf.variable_scope("Model", reuse=True, initializer=initializer):
            valid_graph = HanModelGraph(num_classes=num_classes, word_vocab=word_vocab,
                                                  dropout_rate=FLAGS.dropout_rate, learning_rate=FLAGS.learning_rate,
                                                  lambda_l2=FLAGS.lambda_l2,
                                                  context_lstm_dim=FLAGS.context_lstm_dim,
                                                  is_training=False, batch_size = 1)
        print ("dev Graph Build")
        initializer = tf.global_variables_initializer()
        vars_ = {}
        for var in tf.all_variables():
            if "word_embedding" in var.name: continue
            #             if not var.name.startswith("Model"): continue
            vars_[var.name.split(":")[0]] = var
        saver = tf.train.Saver(vars_)

        output_res_file = open(result_dir + '/' + FLAGS.suffix, 'wt')
        output_res_file.write(str(FLAGS))
        with tf.Session() as sess:
            sess.run(initializer)
            train_size = trainDataStream.get_num_instance()
            max_steps = (train_size * FLAGS.max_epochs) // FLAGS.batch_size
            epoch_size = max_steps // (FLAGS.max_epochs)  # + 1

            total_loss = 0.0
            start_time = time.time()
            best_accuracy = 0
            for step in range(max_steps):
                # read data
                # _truth = []
                # _sents_length = []
                # _in_text_words = []
                # for i in range(FLAGS.batch_size):
                #     cur_instance, instance_index = trainDataStream.nextInstance ()
                #     (label,text,label_id, word_idx, sents_length) = cur_instance
                #
                #     _truth.append(label_id)
                #     _sents_length.append(sents_length)
                #     _in_text_words.append(word_idx)
                #
                # feed_dict = {
                #     train_graph.truth: np.array(_truth),
                #     train_graph.sents_length: tuple(_sents_length),
                #     train_graph.in_text_words: tuple(_in_text_words),
                # }

                feed_dict = get_feed_dict(data_stream=trainDataStream, graph=train_graph, batch_size=FLAGS.batch_size, is_testing=False)

                _, loss_value, _score = sess.run([train_graph.train_op, train_graph.loss
                                                     , train_graph.batch_class_scores],
                                                 feed_dict=feed_dict)
                total_loss += loss_value

                if step % 100 == 0:
                    print('{} '.format(step), end="")
                    sys.stdout.flush()
                if (step + 1) % epoch_size == 0 or (step + 1) == max_steps:
                    # print(total_loss)
                    duration = time.time() - start_time
                    start_time = time.time()
                    print(duration, step, "Loss: ", total_loss)
                    output_res_file.write('\nStep %d: loss = %.2f (%.3f sec)\n' % (step, total_loss, duration))
                    total_loss = 0.0
                    # Evaluate against the validation set.
                    output_res_file.write('valid- ')
                    dev_accuracy = evaluate(devDataStream, valid_graph, sess)
                    output_res_file.write("%.2f\n" % dev_accuracy)
                    print("Current dev accuracy is %.2f" % dev_accuracy)
                    if dev_accuracy > best_accuracy:
                        best_accuracy = dev_accuracy
                        saver.save(sess, best_path)
                    output_res_file.write('test- ')
                    test_accuracy = evaluate(testDataStream, valid_graph, sess)
                    print("Current test accuracy is %.2f" % test_accuracy)
                    output_res_file.write("%.2f\n" % test_accuracy)

    output_res_file.close()
    sys.stdout.flush()
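The training loop above relies on a get_feed_dict helper whose definition is not part of this excerpt. A minimal sketch of what it plausibly does, reconstructed from the commented-out per-instance code inside the loop (the is_testing flag is accepted here but its effect is not visible in this excerpt):

import numpy as np

def get_feed_dict(data_stream, graph, batch_size, is_testing=False):
    # Pull batch_size instances from the HanDataStream and bind the graph
    # placeholders, mirroring the commented-out block in the loop above.
    truth, sents_length, in_text_words = [], [], []
    for _ in range(batch_size):
        cur_instance, _index = data_stream.nextInstance()
        (label, text, label_id, word_idx, cur_sent_length) = cur_instance
        truth.append(label_id)
        sents_length.append(cur_sent_length)
        in_text_words.append(word_idx)
    return {
        graph.truth: np.array(truth),
        graph.sents_length: tuple(sents_length),
        graph.in_text_words: tuple(in_text_words),
    }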
Ejemplo n.º 12
0
    ### print("CUDA_VISIBLE_DEVICES " + os.environ['CUDA_VISIBLE_DEVICES'])

    # load the configuration file
    print('Loading configurations from ' + model_prefix + ".config.json")
    FLAGS = namespace_utils.load_namespace(model_prefix + ".config.json")
    FLAGS = NP2P_trainer.enrich_options(FLAGS)

    os.environ["CUDA_VISIBLE_DEVICES"] = FLAGS.gpu
    print("CUDA_VISIBLE_DEVICES " + os.environ['CUDA_VISIBLE_DEVICES'])

    # load vocabs
    print('Loading vocabs.')
    word_vocab = char_vocab = POS_vocab = NER_vocab = None
    if FLAGS.with_word:
        word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2')
        print('word_vocab: {}'.format(word_vocab.word_vecs.shape))
    if FLAGS.with_char:
        char_vocab = Vocab(model_prefix + ".char_vocab", fileformat='txt2')
        print('char_vocab: {}'.format(char_vocab.word_vecs.shape))
    if FLAGS.with_POS:
        POS_vocab = Vocab(model_prefix + ".POS_vocab", fileformat='txt2')
        print('POS_vocab: {}'.format(POS_vocab.word_vecs.shape))
    if FLAGS.with_NER:
        NER_vocab = Vocab(model_prefix + ".NER_vocab", fileformat='txt2')
        print('NER_vocab: {}'.format(NER_vocab.word_vecs.shape))
    if FLAGS.with_template:
        template_vocab = Vocab(model_prefix + ".template_vocab", fileformat='txt2')
        print('template_vocab: {}'.format(template_vocab.word_vecs.shape))
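namespace_utils.load_namespace is used throughout these examples to rehydrate a saved ".config.json" into an attribute-style FLAGS object; its implementation is not shown here. A minimal sketch of such a loader, assuming the config file is a flat JSON object:

import json
from argparse import Namespace

def load_namespace(path):
    # Read the saved configuration and expose its keys as attributes
    # (FLAGS.with_char, FLAGS.gpu, ...).
    with open(path, 'r') as f:
        config = json.load(f)
    return Namespace(**config)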

Ejemplo n.º 13
0
    out_path = args.out_path

    #print("CUDA_VISIBLE_DEVICES " + os.environ['CUDA_VISIBLE_DEVICES'])

    # load the configuration file
    print('Loading configurations from ' + model_prefix + ".config.json")
    FLAGS = namespace_utils.load_namespace(model_prefix + ".config.json")
    FLAGS = MHQA_trainer.enrich_options(FLAGS)

    if FLAGS.max_passage_size < 3000:
        FLAGS.max_passage_size = 3000
    print('Maximal passage size {}'.format(FLAGS.max_passage_size))

    # load vocabs
    print('Loading vocabs.')
    word_vocab = Vocab(word_vec_path, fileformat='txt2')
    print('word_vocab: {}'.format(word_vocab.word_vecs.shape))
    char_vocab = None
    if FLAGS.with_char:
        char_vocab = Vocab(model_prefix + ".char_vocab", fileformat='txt2')
        print('char_vocab: {}'.format(char_vocab.word_vecs.shape))

    subset_ids = json.load(codecs.open('/u/nalln478/ws/exp.multihop_qa/data.wikihop/distance_subset.json', 'rU', 'utf-8'))
    subset_ids = set(subset_ids)

    print('Loading test set from {}.'.format(in_path))
    testset, test_filtered, _ = MHQA_data_stream.read_data_file(in_path, FLAGS, subset_ids=subset_ids)
    print('Number of samples: {}'.format(len(testset)))

    print('Build DataStream ... ')
    testDataStream = MHQA_data_stream.DataStream(testset, word_vocab, char_vocab, options=FLAGS,
Ejemplo n.º 14
0
        reasonet_lambda = 10
        reasonet_terminate_mode = 'original'
        reasonet_keep_first = True
        reasonet_logit_combine = 'sum'
        if reasonet_training:
            reasonet_steps = FLAGS.reasonet_steps
            reasonet_hidden_dim = FLAGS.reasonet_hidden_dim
            reasonet_lambda = FLAGS.reasonet_lambda
            reasonet_terminate_mode = FLAGS.reasonet_terminate_mode
            reasonet_keep_first = FLAGS.reasonet_keep_first
            reasonet_logit_combine = FLAGS.reasonet_logit_combine

    # load vocabs
    print('Loading vocabs.')
    word_vocab = Vocab(word_vec_path,
                       fileformat='txt3',
                       tolower=FLAGS.use_lower_letter)
    label_vocab = Vocab(model_prefix + ".label_vocab",
                        fileformat='txt2',
                        tolower=FLAGS.use_lower_letter)
    print('word_vocab: {}'.format(word_vocab.word_vecs.shape))
    print('label_vocab: {}'.format(label_vocab.word_vecs.shape))
    num_classes = label_vocab.size()

    POS_vocab = None
    NER_vocab = None
    char_vocab = None
    if with_POS:
        POS_vocab = Vocab(model_prefix + ".POS_vocab",
                          fileformat='txt2',
                          tolower=FLAGS.use_lower_letter)
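The block above hard-codes defaults and only reads the reasonet options from FLAGS when reasonet_training is set. If the goal were merely to tolerate older config files that lack these options, the same defaults could also be expressed with getattr; a sketch, not taken from the original file:

reasonet_lambda = getattr(FLAGS, 'reasonet_lambda', 10)
reasonet_terminate_mode = getattr(FLAGS, 'reasonet_terminate_mode', 'original')
reasonet_keep_first = getattr(FLAGS, 'reasonet_keep_first', True)
reasonet_logit_combine = getattr(FLAGS, 'reasonet_logit_combine', 'sum')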
Ejemplo n.º 15
0
def main(_):
    print('Configurations:')
    print(FLAGS)

    log_dir = FLAGS.model_dir
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    path_prefix = log_dir + "/NP2P.{}".format(FLAGS.suffix)
    init_model_prefix = FLAGS.init_model  # "/u/zhigwang/zhigwang1/sentence_generation/mscoco/logs/NP2P.phrase_ce_train"
    log_file_path = path_prefix + ".log"
    print('Log file path: {}'.format(log_file_path))
    log_file = open(log_file_path, 'wt')
    log_file.write("{}\n".format(FLAGS))
    log_file.flush()

    # save configuration
    namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")

    print('Loading train set.')
    if FLAGS.infile_format == 'fof':
        trainset, train_ans_len = NP2P_data_stream.read_generation_datasets_from_fof(
            FLAGS.train_path, isLower=FLAGS.isLower)
        if FLAGS.max_answer_len > train_ans_len:
            FLAGS.max_answer_len = train_ans_len
    else:
        trainset, train_ans_len = NP2P_data_stream.read_all_GQA_questions(
            FLAGS.train_path, isLower=FLAGS.isLower)
    print('Number of training samples: {}'.format(len(trainset)))

    print('Loading test set.')
    if FLAGS.infile_format == 'fof':
        testset, test_ans_len = NP2P_data_stream.read_generation_datasets_from_fof(
            FLAGS.test_path, isLower=FLAGS.isLower)
    else:
        testset, test_ans_len = NP2P_data_stream.read_all_GQA_questions(
            FLAGS.test_path, isLower=FLAGS.isLower)
    print('Number of test samples: {}'.format(len(testset)))

    max_actual_len = max(train_ans_len, test_ans_len)
    print('Max answer length: {}, truncated to {}'.format(
        max_actual_len, FLAGS.max_answer_len))

    word_vocab = None
    POS_vocab = None
    NER_vocab = None
    char_vocab = None
    has_pretrained_model = False
    best_path = path_prefix + ".best.model"
    if os.path.exists(init_model_prefix + ".best.model.index"):
        has_pretrained_model = True
        print('!!Existing pretrained model. Loading vocabs.')
        if FLAGS.with_word:
            word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2')
            print('word_vocab: {}'.format(word_vocab.word_vecs.shape))
        if FLAGS.with_char:
            char_vocab = Vocab(init_model_prefix + ".char_vocab",
                               fileformat='txt2')
            print('char_vocab: {}'.format(char_vocab.word_vecs.shape))
        if FLAGS.with_POS:
            POS_vocab = Vocab(init_model_prefix + ".POS_vocab",
                              fileformat='txt2')
            print('POS_vocab: {}'.format(POS_vocab.word_vecs.shape))
        if FLAGS.with_NER:
            NER_vocab = Vocab(init_model_prefix + ".NER_vocab",
                              fileformat='txt2')
            print('NER_vocab: {}'.format(NER_vocab.word_vecs.shape))
    else:
        print('Collecting vocabs.')
        (allWords, allChars, allPOSs,
         allNERs) = NP2P_data_stream.collect_vocabs(trainset)
        print('Number of words: {}'.format(len(allWords)))
        print('Number of allChars: {}'.format(len(allChars)))
        print('Number of allPOSs: {}'.format(len(allPOSs)))
        print('Number of allNERs: {}'.format(len(allNERs)))

        if FLAGS.with_word:
            word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2')
        if FLAGS.with_char:
            char_vocab = Vocab(voc=allChars,
                               dim=FLAGS.char_dim,
                               fileformat='build')
            char_vocab.dump_to_txt2(path_prefix + ".char_vocab")
        if FLAGS.with_POS:
            POS_vocab = Vocab(voc=allPOSs,
                              dim=FLAGS.POS_dim,
                              fileformat='build')
            POS_vocab.dump_to_txt2(path_prefix + ".POS_vocab")
        if FLAGS.with_NER:
            NER_vocab = Vocab(voc=allNERs,
                              dim=FLAGS.NER_dim,
                              fileformat='build')
            NER_vocab.dump_to_txt2(path_prefix + ".NER_vocab")

    print('word vocab size {}'.format(word_vocab.vocab_size))
    sys.stdout.flush()

    print('Build DataStream ... ')
    trainDataStream = NP2P_data_stream.QADataStream(trainset,
                                                    word_vocab,
                                                    char_vocab,
                                                    POS_vocab,
                                                    NER_vocab,
                                                    options=FLAGS,
                                                    isShuffle=True,
                                                    isLoop=True,
                                                    isSort=True)

    devDataStream = NP2P_data_stream.QADataStream(testset,
                                                  word_vocab,
                                                  char_vocab,
                                                  POS_vocab,
                                                  NER_vocab,
                                                  options=FLAGS,
                                                  isShuffle=False,
                                                  isLoop=False,
                                                  isSort=True)
    print('Number of instances in trainDataStream: {}'.format(
        trainDataStream.get_num_instance()))
    print('Number of instances in devDataStream: {}'.format(
        devDataStream.get_num_instance()))
    print('Number of batches in trainDataStream: {}'.format(
        trainDataStream.get_num_batch()))
    print('Number of batches in devDataStream: {}'.format(
        devDataStream.get_num_batch()))
    sys.stdout.flush()

    init_scale = 0.01
    with tf.Graph().as_default():
        initializer = tf.random_uniform_initializer(-init_scale, init_scale)
        with tf.name_scope("Train"):
            with tf.variable_scope("Model",
                                   reuse=None,
                                   initializer=initializer):
                train_graph = ModelGraph(word_vocab=word_vocab,
                                         char_vocab=char_vocab,
                                         POS_vocab=POS_vocab,
                                         NER_vocab=NER_vocab,
                                         options=FLAGS,
                                         mode="rl_train_for_phrase")

        with tf.name_scope("Valid"):
            with tf.variable_scope("Model",
                                   reuse=True,
                                   initializer=initializer):
                valid_graph = ModelGraph(word_vocab=word_vocab,
                                         char_vocab=char_vocab,
                                         POS_vocab=POS_vocab,
                                         NER_vocab=NER_vocab,
                                         options=FLAGS,
                                         mode="decode")

        initializer = tf.global_variables_initializer()

        vars_ = {}
        for var in tf.global_variables():
            if "word_embedding" in var.name: continue
            if not var.name.startswith("Model"): continue
            vars_[var.name.split(":")[0]] = var
        saver = tf.train.Saver(vars_)

        sess = tf.Session()
        sess.run(initializer)
        if has_pretrained_model:
            print("Restoring model from " + init_model_prefix + ".best.model")
            saver.restore(sess, init_model_prefix + ".best.model")
            print("DONE!")
        sys.stdout.flush()

        # for first-time rl training, we get the current BLEU score
        print("First-time rl training, get the current BLEU score on dev")
        sys.stdout.flush()
        best_bleu = evaluate(sess,
                             valid_graph,
                             devDataStream,
                             word_vocab,
                             options=FLAGS)
        print('First-time bleu = %.4f' % best_bleu)
        log_file.write('First-time bleu = %.4f\n' % best_bleu)

        print('Start the training loop.')
        sys.stdout.flush()
        train_size = trainDataStream.get_num_batch()
        max_steps = train_size * FLAGS.max_epochs
        total_loss = 0.0
        start_time = time.time()
        for step in xrange(max_steps):
            cur_batch = trainDataStream.nextBatch()
            if FLAGS.with_baseline:
                # greedy search
                (greedy_sentences, _, _,
                 _) = NP2P_beam_decoder.search(sess,
                                               valid_graph,
                                               word_vocab,
                                               cur_batch,
                                               FLAGS,
                                               decode_mode="greedy")

            if FLAGS.with_target_lattice:
                (sampled_sentences, sampled_prediction_lengths,
                 sampled_generator_input_idx, sampled_generator_output_idx
                 ) = cur_batch.sample_a_partition()
            else:
                # multinomial sampling
                (sampled_sentences, sampled_prediction_lengths,
                 sampled_generator_input_idx,
                 sampled_generator_output_idx) = NP2P_beam_decoder.search(
                     sess,
                     valid_graph,
                     word_vocab,
                     cur_batch,
                     FLAGS,
                     decode_mode="multinomial")
            # calculate rewards
            rewards = []
            for i in xrange(cur_batch.batch_size):
                #                 print(sampled_sentences[i])
                #                 print(sampled_generator_input_idx[i])
                #                 print(sampled_generator_output_idx[i])
                cur_toks = cur_batch.instances[i][1].tokText.split()
                #                 r = sentence_bleu([cur_toks], sampled_sentences[i].split(), smoothing_function=cc.method3)
                r = 1.0
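                # NOTE: with the sentence_bleu call above commented out, r stays
                # fixed at 1.0, so the reward below becomes 1.0 minus the greedy
                # baseline's BLEU (or a constant 1.0 when with_baseline is off).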
                b = 0.0
                if FLAGS.with_baseline:
                    b = sentence_bleu([cur_toks],
                                      greedy_sentences[i].split(),
                                      smoothing_function=cc.method3)


#                 r = metric_utils.evaluate_captions([cur_toks],[sampled_sentences[i]])
#                 b = metric_utils.evaluate_captions([cur_toks],[greedy_sentences[i]])
                rewards.append(1.0 * (r - b))
            rewards = np.array(rewards, dtype=np.float32)
            #             sys.exit(-1)

            # update parameters
            feed_dict = train_graph.run_encoder(sess,
                                                cur_batch,
                                                FLAGS,
                                                only_feed_dict=True)
            feed_dict[train_graph.reward] = rewards
            feed_dict[
                train_graph.gen_input_words] = sampled_generator_input_idx
            feed_dict[
                train_graph.in_answer_words] = sampled_generator_output_idx
            feed_dict[train_graph.answer_lengths] = sampled_prediction_lengths
            (_,
             loss_value) = sess.run([train_graph.train_op, train_graph.loss],
                                    feed_dict)
            total_loss += loss_value

            if step % 100 == 0:
                print('{} '.format(step), end="")
                sys.stdout.flush()

            # Save a checkpoint and evaluate the model periodically.
            if (step + 1) % trainDataStream.get_num_batch() == 0 or (
                    step + 1) == max_steps:
                print()
                duration = time.time() - start_time
                print('Step %d: loss = %.2f (%.3f sec)' %
                      (step, total_loss, duration))
                log_file.write('Step %d: loss = %.2f (%.3f sec)\n' %
                               (step, total_loss, duration))
                log_file.flush()
                sys.stdout.flush()
                total_loss = 0.0

                # Evaluate against the validation set.
                start_time = time.time()
                print('Validation Data Eval:')
                dev_bleu = evaluate(sess,
                                    valid_graph,
                                    devDataStream,
                                    word_vocab,
                                    options=FLAGS)
                print('Dev bleu = %.4f' % dev_bleu)
                log_file.write('Dev bleu = %.4f\n' % dev_bleu)
                log_file.flush()
                if best_bleu < dev_bleu:
                    print('Saving weights, BLEU {} (prev_best) < {} (cur)'.
                          format(best_bleu, dev_bleu))
                    best_bleu = dev_bleu
                    saver.save(sess, best_path)  # TODO: save model
                duration = time.time() - start_time
                print('Duration %.3f sec' % (duration))
                sys.stdout.flush()
                log_file.write('Duration %.3f sec\n' % (duration))
                log_file.flush()

    log_file.close()
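The reward in the loop above is self-critical: BLEU of the sampled sentence minus BLEU of the greedy baseline (in this excerpt the sentence_bleu line for the sample is commented out and r is pinned to 1.0). A standalone sketch of that reward, assuming nltk is available and cc is an nltk SmoothingFunction instance, as the calls above suggest:

from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

cc = SmoothingFunction()

def self_critical_reward(reference, sampled, greedy):
    # BLEU(sampled) - BLEU(greedy baseline); all arguments are plain strings.
    ref_toks = reference.split()
    r = sentence_bleu([ref_toks], sampled.split(), smoothing_function=cc.method3)
    b = sentence_bleu([ref_toks], greedy.split(), smoothing_function=cc.method3)
    return r - b

# A positive reward means the sampled sentence beat the greedy decode.
print(self_critical_reward("the cat sat on the mat",
                           "the cat sat on a mat",
                           "a cat on mat"))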
Ejemplo n.º 16
0
def main(_):


    #for x in range (100):
    #    Generate_random_initialization()
    #    print(FLAGS.is_aggregation_lstm, FLAGS.context_lstm_dim, FLAGS.context_layer_num, FLAGS.aggregation_lstm_dim, FLAGS.aggregation_layer_num, FLAGS.max_window_size, FLAGS.MP_dim)

    print('Configurations:')
    #print(FLAGS)


    train_path = FLAGS.train_path
    dev_path = FLAGS.dev_path
    test_path = FLAGS.test_path
    word_vec_path = FLAGS.word_vec_path
    log_dir = FLAGS.model_dir
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    
    path_prefix = log_dir + "/SentenceMatch.{}".format(FLAGS.suffix)

    namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")

    # build vocabs
    word_vocab = Vocab(word_vec_path, fileformat='txt3')
    best_path = path_prefix + '.best.model'
    char_path = path_prefix + ".char_vocab"
    label_path = path_prefix + ".label_vocab"
    POS_path = path_prefix + ".POS_vocab"
    NER_path = path_prefix + ".NER_vocab"
    has_pre_trained_model = False
    POS_vocab = None
    NER_vocab = None
    if os.path.exists(best_path):
        has_pre_trained_model = True
        label_vocab = Vocab(label_path, fileformat='txt2')
        char_vocab = Vocab(char_path, fileformat='txt2')
        if FLAGS.with_POS: POS_vocab = Vocab(POS_path, fileformat='txt2')
        if FLAGS.with_NER: NER_vocab = Vocab(NER_path, fileformat='txt2')
    else:
        print('Collect words, chars and labels ...')
        (all_words, all_chars, all_labels, all_POSs, all_NERs) = collect_vocabs(train_path, with_POS=FLAGS.with_POS, with_NER=FLAGS.with_NER)
        print('Number of words: {}'.format(len(all_words)))
        print('Number of labels: {}'.format(len(all_labels)))
        label_vocab = Vocab(fileformat='voc', voc=all_labels,dim=2)
        label_vocab.dump_to_txt2(label_path)

        print('Number of chars: {}'.format(len(all_chars)))
        char_vocab = Vocab(fileformat='voc', voc=all_chars,dim=FLAGS.char_emb_dim)
        char_vocab.dump_to_txt2(char_path)
        
        if FLAGS.with_POS:
            print('Number of POSs: {}'.format(len(all_POSs)))
            POS_vocab = Vocab(fileformat='voc', voc=all_POSs,dim=FLAGS.POS_dim)
            POS_vocab.dump_to_txt2(POS_path)
        if FLAGS.with_NER:
            print('Number of NERs: {}'.format(len(all_NERs)))
            NER_vocab = Vocab(fileformat='voc', voc=all_NERs,dim=FLAGS.NER_dim)
            NER_vocab.dump_to_txt2(NER_path)
            

    print('word_vocab shape is {}'.format(word_vocab.word_vecs.shape))
    print('tag_vocab shape is {}'.format(label_vocab.word_vecs.shape))
    num_classes = label_vocab.size()

    print('Build SentenceMatchDataStream ... ')
    trainDataStream = SentenceMatchDataStream(train_path, word_vocab=word_vocab, char_vocab=char_vocab, 
                                              POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab, 
                                              batch_size=FLAGS.batch_size, isShuffle=True, isLoop=True, isSort=True, 
                                              max_char_per_word=FLAGS.max_char_per_word, max_sent_length=FLAGS.max_sent_length,
                                              is_as=FLAGS.is_answer_selection)
                                    
    devDataStream = SentenceMatchDataStream(dev_path, word_vocab=word_vocab, char_vocab=char_vocab,
                                              POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab, 
                                              batch_size=FLAGS.batch_size, isShuffle=False, isLoop=True, isSort=True,
                                              max_char_per_word=FLAGS.max_char_per_word, max_sent_length=FLAGS.max_sent_length,
                                              is_as=FLAGS.is_answer_selection)

    testDataStream = SentenceMatchDataStream(test_path, word_vocab=word_vocab, char_vocab=char_vocab, 
                                              POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab, 
                                              batch_size=FLAGS.batch_size, isShuffle=False, isLoop=True, isSort=True,
                                              max_char_per_word=FLAGS.max_char_per_word, max_sent_length=FLAGS.max_sent_length,
                                              is_as=FLAGS.is_answer_selection)

    print('Number of instances in trainDataStream: {}'.format(trainDataStream.get_num_instance()))
    print('Number of instances in devDataStream: {}'.format(devDataStream.get_num_instance()))
    print('Number of instances in testDataStream: {}'.format(testDataStream.get_num_instance()))
    print('Number of batches in trainDataStream: {}'.format(trainDataStream.get_num_batch()))
    print('Number of batches in devDataStream: {}'.format(devDataStream.get_num_batch()))
    print('Number of batches in testDataStream: {}'.format(testDataStream.get_num_batch()))
    
    sys.stdout.flush()
    if FLAGS.wo_char: char_vocab = None
    output_res_index = 1
    while True:
        Generate_random_initialization()
        st_cuda = ''
        if FLAGS.is_server == True:
            st_cuda = str(os.environ['CUDA_VISIBLE_DEVICES']) + '.'
        output_res_file = open('../result/' + st_cuda + str(output_res_index), 'wt')
        output_res_index += 1
        output_res_file.write(str(FLAGS) + '\n\n')
        stt = str (FLAGS)
        best_accuracy = 0.0
        init_scale = 0.01
        with tf.Graph().as_default():
            initializer = tf.random_uniform_initializer(-init_scale, init_scale)
    #         with tf.name_scope("Train"):
            with tf.variable_scope("Model", reuse=None, initializer=initializer):
                train_graph = SentenceMatchModelGraph(num_classes, word_vocab=word_vocab, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab,
                                                      dropout_rate=FLAGS.dropout_rate, learning_rate=FLAGS.learning_rate, optimize_type=FLAGS.optimize_type,
                                                      lambda_l2=FLAGS.lambda_l2, char_lstm_dim=FLAGS.char_lstm_dim, context_lstm_dim=FLAGS.context_lstm_dim,
                                                      aggregation_lstm_dim=FLAGS.aggregation_lstm_dim, is_training=True, MP_dim=FLAGS.MP_dim,
                                                      context_layer_num=FLAGS.context_layer_num, aggregation_layer_num=FLAGS.aggregation_layer_num,
                                                      fix_word_vec=FLAGS.fix_word_vec, with_filter_layer=FLAGS.with_filter_layer, with_input_highway=FLAGS.with_highway,
                                                      word_level_MP_dim=FLAGS.word_level_MP_dim,
                                                      with_match_highway=FLAGS.with_match_highway, with_aggregation_highway=FLAGS.with_aggregation_highway,
                                                      highway_layer_num=FLAGS.highway_layer_num, with_lex_decomposition=FLAGS.with_lex_decomposition,
                                                      lex_decompsition_dim=FLAGS.lex_decompsition_dim,
                                                      with_left_match=(not FLAGS.wo_left_match), with_right_match=(not FLAGS.wo_right_match),
                                                      with_full_match=(not FLAGS.wo_full_match), with_maxpool_match=(not FLAGS.wo_maxpool_match),
                                                      with_attentive_match=(not FLAGS.wo_attentive_match), with_max_attentive_match=(not FLAGS.wo_max_attentive_match),
                                                      with_bilinear_att=(FLAGS.attention_type)
                                                      , type1=FLAGS.type1, type2 = FLAGS.type2, type3=FLAGS.type3,
                                                      with_aggregation_attention=not FLAGS.wo_agg_self_att,
                                                      is_answer_selection= FLAGS.is_answer_selection,
                                                      is_shared_attention=FLAGS.is_shared_attention,
                                                      modify_loss=FLAGS.modify_loss, is_aggregation_lstm=FLAGS.is_aggregation_lstm
                                                      , max_window_size=FLAGS.max_window_size
                                                      , prediction_mode=FLAGS.prediction_mode,
                                                      context_lstm_dropout=not FLAGS.wo_lstm_drop_out,
                                                      is_aggregation_siamese=FLAGS.is_aggregation_siamese
                                                      , unstack_cnn=FLAGS.unstack_cnn,with_context_self_attention=FLAGS.with_context_self_attention)
                tf.summary.scalar("Training Loss", train_graph.get_loss()) # Add a scalar summary for the snapshot loss.

    #         with tf.name_scope("Valid"):
            with tf.variable_scope("Model", reuse=True, initializer=initializer):
                valid_graph = SentenceMatchModelGraph(num_classes, word_vocab=word_vocab, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab,
                                                      dropout_rate=FLAGS.dropout_rate, learning_rate=FLAGS.learning_rate, optimize_type=FLAGS.optimize_type,
                                                      lambda_l2=FLAGS.lambda_l2, char_lstm_dim=FLAGS.char_lstm_dim, context_lstm_dim=FLAGS.context_lstm_dim,
                                                      aggregation_lstm_dim=FLAGS.aggregation_lstm_dim, is_training=False, MP_dim=FLAGS.MP_dim,
                                                      context_layer_num=FLAGS.context_layer_num, aggregation_layer_num=FLAGS.aggregation_layer_num,
                                                      fix_word_vec=FLAGS.fix_word_vec, with_filter_layer=FLAGS.with_filter_layer, with_input_highway=FLAGS.with_highway,
                                                      word_level_MP_dim=FLAGS.word_level_MP_dim,
                                                      with_match_highway=FLAGS.with_match_highway, with_aggregation_highway=FLAGS.with_aggregation_highway,
                                                      highway_layer_num=FLAGS.highway_layer_num, with_lex_decomposition=FLAGS.with_lex_decomposition,
                                                      lex_decompsition_dim=FLAGS.lex_decompsition_dim,
                                                      with_left_match=(not FLAGS.wo_left_match), with_right_match=(not FLAGS.wo_right_match),
                                                      with_full_match=(not FLAGS.wo_full_match), with_maxpool_match=(not FLAGS.wo_maxpool_match),
                                                      with_attentive_match=(not FLAGS.wo_attentive_match), with_max_attentive_match=(not FLAGS.wo_max_attentive_match),
                                                      with_bilinear_att=(FLAGS.attention_type)
                                                      , type1=FLAGS.type1, type2 = FLAGS.type2, type3=FLAGS.type3,
                                                      with_aggregation_attention=not FLAGS.wo_agg_self_att,
                                                      is_answer_selection= FLAGS.is_answer_selection,
                                                      is_shared_attention=FLAGS.is_shared_attention,
                                                      modify_loss=FLAGS.modify_loss, is_aggregation_lstm=FLAGS.is_aggregation_lstm,
                                                      max_window_size=FLAGS.max_window_size
                                                      , prediction_mode=FLAGS.prediction_mode,
                                                      context_lstm_dropout=not FLAGS.wo_lstm_drop_out,
                                                      is_aggregation_siamese=FLAGS.is_aggregation_siamese
                                                      , unstack_cnn=FLAGS.unstack_cnn,with_context_self_attention=FLAGS.with_context_self_attention)


            initializer = tf.global_variables_initializer()
            vars_ = {}
            #for var in tf.all_variables():
            for var in tf.global_variables():
                if "word_embedding" in var.name: continue
    #             if not var.name.startswith("Model"): continue
                vars_[var.name.split(":")[0]] = var
            saver = tf.train.Saver(vars_)

            with tf.Session() as sess:
                sess.run(initializer)
                if has_pre_trained_model:
                    print("Restoring model from " + best_path)
                    saver.restore(sess, best_path)
                    print("DONE!")

                print('Start the training loop.')
                train_size = trainDataStream.get_num_batch()
                max_steps = train_size * FLAGS.max_epochs
                total_loss = 0.0
                start_time = time.time()

                for step in xrange(max_steps):
                    # read data
                    cur_batch, batch_index = trainDataStream.nextBatch()
                    (label_batch, sent1_batch, sent2_batch, label_id_batch, word_idx_1_batch, word_idx_2_batch,
                                         char_matrix_idx_1_batch, char_matrix_idx_2_batch, sent1_length_batch, sent2_length_batch,
                                         sent1_char_length_batch, sent2_char_length_batch,
                                         POS_idx_1_batch, POS_idx_2_batch, NER_idx_1_batch, NER_idx_2_batch) = cur_batch
                    feed_dict = {
                                 train_graph.get_truth(): label_id_batch,
                                 train_graph.get_question_lengths(): sent1_length_batch,
                                 train_graph.get_passage_lengths(): sent2_length_batch,
                                 train_graph.get_in_question_words(): word_idx_1_batch,
                                 train_graph.get_in_passage_words(): word_idx_2_batch,
        #                          train_graph.get_question_char_lengths(): sent1_char_length_batch,
        #                          train_graph.get_passage_char_lengths(): sent2_char_length_batch,
        #                          train_graph.get_in_question_chars(): char_matrix_idx_1_batch,
        #                          train_graph.get_in_passage_chars(): char_matrix_idx_2_batch,
                                 }
                    if char_vocab is not None:
                        feed_dict[train_graph.get_question_char_lengths()] = sent1_char_length_batch
                        feed_dict[train_graph.get_passage_char_lengths()] = sent2_char_length_batch
                        feed_dict[train_graph.get_in_question_chars()] = char_matrix_idx_1_batch
                        feed_dict[train_graph.get_in_passage_chars()] = char_matrix_idx_2_batch

                    if POS_vocab is not None:
                        feed_dict[train_graph.get_in_question_poss()] = POS_idx_1_batch
                        feed_dict[train_graph.get_in_passage_poss()] = POS_idx_2_batch

                    if NER_vocab is not None:
                        feed_dict[train_graph.get_in_question_ners()] = NER_idx_1_batch
                        feed_dict[train_graph.get_in_passage_ners()] = NER_idx_2_batch

                    if FLAGS.is_answer_selection == True:
                        feed_dict[train_graph.get_question_count()] = trainDataStream.question_count(batch_index)
                        feed_dict[train_graph.get_answer_count()] = trainDataStream.answer_count(batch_index)

                    _, loss_value = sess.run([train_graph.get_train_op(), train_graph.get_loss()], feed_dict=feed_dict)
                    total_loss += loss_value
                    if FLAGS.is_answer_selection == True and FLAGS.is_server == False:
                        print ("q: {} a: {} loss_value: {}".format(trainDataStream.question_count(batch_index)
                                                   ,trainDataStream.answer_count(batch_index), loss_value))

                    if step % 100==0:
                        print('{} '.format(step), end="")
                        sys.stdout.flush()

                    # Save a checkpoint and evaluate the model periodically.
                    if (step + 1) % trainDataStream.get_num_batch() == 0 or (step + 1) == max_steps:
                        #print(total_loss)
                        # Print status to stdout.
                        duration = time.time() - start_time
                        start_time = time.time()
                        output_res_file.write('Step %d: loss = %.2f (%.3f sec)\n' % (step, total_loss, duration))
                        total_loss = 0.0


                        #Evaluate against the validation set.
                        output_res_file.write('valid- ')
                        my_map, my_mrr = evaluate(devDataStream, valid_graph, sess,char_vocab=char_vocab,
                                            POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab)
                        output_res_file.write("map: '{}', mrr: '{}'\n".format(my_map, my_mrr))
                        #print ("dev map: {}".format(my_map))
                        #print("Current accuracy is %.2f" % accuracy)

                        #accuracy = my_map
                        #if accuracy>best_accuracy:
                        #    best_accuracy = accuracy
                        #    saver.save(sess, best_path)

                        # Evaluate against the test set.
                        output_res_file.write ('test- ')
                        my_map, my_mrr = evaluate(testDataStream, valid_graph, sess, char_vocab=char_vocab,
                                 POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab)
                        output_res_file.write("map: '{}', mrr: '{}\n\n".format(my_map, my_mrr))
                        if FLAGS.is_server == False:
                            print ("test map: {}".format(my_map))

                        #Evaluate against the train set only for final epoch.
                        if (step + 1) == max_steps:
                            output_res_file.write ('train- ')
                            my_map, my_mrr = evaluate(trainDataStream, valid_graph, sess, char_vocab=char_vocab,
                                POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab)
                            output_res_file.write("map: '{}', mrr: '{}'\n".format(my_map, my_mrr))

        # print("Best accuracy on dev set is %.2f" % best_accuracy)
        # # decoding
        # print('Decoding on the test set:')
        # init_scale = 0.01
        # with tf.Graph().as_default():
        #     initializer = tf.random_uniform_initializer(-init_scale, init_scale)
        #     with tf.variable_scope("Model", reuse=False, initializer=initializer):
        #         valid_graph = SentenceMatchModelGraph(num_classes, word_vocab=word_vocab, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab,
        #              dropout_rate=FLAGS.dropout_rate, learning_rate=FLAGS.learning_rate, optimize_type=FLAGS.optimize_type,
        #              lambda_l2=FLAGS.lambda_l2, char_lstm_dim=FLAGS.char_lstm_dim, context_lstm_dim=FLAGS.context_lstm_dim,
        #              aggregation_lstm_dim=FLAGS.aggregation_lstm_dim, is_training=False, MP_dim=FLAGS.MP_dim,
        #              context_layer_num=FLAGS.context_layer_num, aggregation_layer_num=FLAGS.aggregation_layer_num,
        #              fix_word_vec=FLAGS.fix_word_vec,with_filter_layer=FLAGS.with_filter_layer, with_highway=FLAGS.with_highway,
        #              word_level_MP_dim=FLAGS.word_level_MP_dim,
        #              with_match_highway=FLAGS.with_match_highway, with_aggregation_highway=FLAGS.with_aggregation_highway,
        #              highway_layer_num=FLAGS.highway_layer_num, with_lex_decomposition=FLAGS.with_lex_decomposition,
        #              lex_decompsition_dim=FLAGS.lex_decompsition_dim,
        #              with_left_match=(not FLAGS.wo_left_match), with_right_match=(not FLAGS.wo_right_match),
        #              with_full_match=(not FLAGS.wo_full_match), with_maxpool_match=(not FLAGS.wo_maxpool_match),
        #              with_attentive_match=(not FLAGS.wo_attentive_match), with_max_attentive_match=(not FLAGS.wo_max_attentive_match),
        #                                               with_bilinear_att=(not FLAGS.wo_bilinear_att)
        #                                               , type1=FLAGS.type1, type2 = FLAGS.type2, type3=FLAGS.type3,
        #                                               with_aggregation_attention=not FLAGS.wo_agg_self_att,
        #                                               is_answer_selection= FLAGS.is_answer_selection,
        #                                               is_shared_attention=FLAGS.is_shared_attention,
        #                                               modify_loss=FLAGS.modify_loss,is_aggregation_lstm=FLAGS.is_aggregation_lstm,
        #                                               max_window_size=FLAGS.max_window_size,
        #                                               prediction_mode=FLAGS.prediction_mode,
        #                                               context_lstm_dropout=not FLAGS.wo_lstm_drop_out,
        #                                              is_aggregation_siamese=FLAGS.is_aggregation_siamese)
        #
        #     vars_ = {}
        #     for var in tf.global_variables():
        #         if "word_embedding" in var.name: continue
        #         if not var.name.startswith("Model"): continue
        #         vars_[var.name.split(":")[0]] = var
        #     saver = tf.train.Saver(vars_)
        #
        #     sess = tf.Session()
        #     sess.run(tf.global_variables_initializer())
        #     step = 0
        #     saver.restore(sess, best_path)
        #
        #     accuracy, mrr = evaluate(testDataStream, valid_graph, sess,char_vocab=char_vocab,POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab
        #                         , mode='trec')
        #     output_res_file.write("map for test set is %.2f\n" % accuracy)
        output_res_file.close()
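Both graphs above share one checkpointing pattern: build a Saver over every global variable except the (fixed) word embedding, so the large embedding matrix is never written into checkpoints. A sketch of that pattern as a reusable helper (TF1-style API):

import tensorflow as tf

def build_saver_without_embeddings(scope_prefix="Model"):
    vars_ = {}
    for var in tf.global_variables():
        if "word_embedding" in var.name:
            continue                      # skip the fixed embedding matrix
        if not var.name.startswith(scope_prefix):
            continue                      # keep only model-scope variables
        vars_[var.name.split(":")[0]] = var
    return tf.train.Saver(vars_)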
Ejemplo n.º 17
0
                        help='The path to the output file.')
    parser.add_argument('--word_vec_path',
                        type=str,
                        help='word embedding file for the input file.')

    args, unparsed = parser.parse_known_args()

    # load the configuration file
    tf.logging.info('Loading configurations.')
    options = namespace_utils.load_namespace(args.model_prefix +
                                             "KEIM.snli.config.json")
    if args.word_vec_path is None: args.word_vec_path = options.word_vec_path

    # load vocabs
    tf.logging.info('Loading vocabs.')
    word_vocab = Vocab(args.word_vec_path, fileformat='txt3')
    tf.logging.info('word_vocab: {}'.format(word_vocab.word_vecs.shape))

    lemma_vocab = Vocab(options.lemma_vec_path, fileformat='txt3')
    tf.logging.info('lemma_vocab: {}'.format(lemma_vocab.word_vecs.shape))

    char_vocab = None
    if options.with_char:
        char_vocab = Vocab(args.model_prefix + ".char_vocab",
                           fileformat='txt2')
        tf.logging.info('char_vocab: {}'.format(char_vocab.word_vecs.shape))

    tf.logging.info('Build SentenceMatchDataStream ... ')
    testDataStream = DataStream(args.in_path,
                                word_vocab=word_vocab,
                                char_vocab=char_vocab,
Ejemplo n.º 18
0
    cache_size = args.cache_size
    use_dep = args.decode

    oracle.utils.pushidx_feat_num = (1 + args.cache_size) * 5

    print("CUDA_VISIBLE_DEVICES " + os.environ['CUDA_VISIBLE_DEVICES'])

    # load the configuration file
    print('Loading configurations from ' + model_prefix + ".config.json")
    FLAGS = namespace_utils.load_namespace(model_prefix + ".config.json")
    FLAGS = NP2P_trainer.enrich_options(FLAGS)

    # load vocabs
    print('Loading vocabs.')
    word_vocab = char_vocab = POS_vocab = NER_vocab = None
    word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2')
    print('word_vocab: {}'.format(word_vocab.word_vecs.shape))
    if FLAGS.with_char:
        char_vocab = Vocab(model_prefix + ".char_vocab", fileformat='txt2')
        print('char_vocab: {}'.format(char_vocab.word_vecs.shape))
    if FLAGS.with_POS:
        POS_vocab = Vocab(model_prefix + ".POS_vocab", fileformat='txt2')
        print('POS_vocab: {}'.format(POS_vocab.word_vecs.shape))
    FLAGS.feat_num = 72 + args.cache_size * 5
    action_vocab = Vocab(model_prefix + ".action_vocab", fileformat='txt2')
    print('action_vocab: {}'.format(action_vocab.word_vecs.shape))
    feat_vocab = Vocab(model_prefix + ".feat_vocab", fileformat='txt2')
    print('feat_vocab: {}'.format(feat_vocab.word_vecs.shape))

    print('Loading test set.')
    if use_dep:
Ejemplo n.º 19
0
    # main(FLAGS)

    # DONOTCHANGE: Reserved for nsml

    train_path = DATASET_PATH
    log_dir = config.model_dir

    char_vocab = None
    # if os.path.exists(best_path + ".index"):
    if config.mode == 'train':
        print('Collecting words, chars and labels ...')
        # (all_words, all_chars, all_labels, all_POSs, all_NERs) = collect_vocabs(train_path)
        (all_words, all_chars, all_labels, all_POSs,
         all_NERs) = collect_vocabs_kin(train_path)
        print('Number of words: {}'.format(len(all_words)))
        label_vocab = Vocab(fileformat='voc', voc=all_labels, dim=2)
        # label_vocab.dump_to_txt2(label_path)
        word_vocab = Vocab(fileformat='voc',
                           voc=all_words,
                           dim=config.word_emb_dim)

        if config.with_char:
            print('Number of chars: {}'.format(len(all_chars)))
            char_vocab = Vocab(fileformat='voc',
                               voc=all_chars,
                               dim=config.char_emb_dim)
            # char_vocab.dump_to_txt2(char_path)
    else:
        print('test seq ')
        word_vocab = []
        label_vocab = []
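Elsewhere in these examples the vocabularies built this way are persisted with dump_to_txt2 and later reloaded with fileformat='txt2'; here both dump calls are commented out. A sketch of that round trip, with placeholder paths:

label_vocab = Vocab(fileformat='voc', voc=all_labels, dim=2)
label_vocab.dump_to_txt2(log_dir + "/model.label_vocab")   # persist next to the model
reloaded_label_vocab = Vocab(log_dir + "/model.label_vocab", fileformat='txt2')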
Ejemplo n.º 20
0
def main(_):
    print('Configurations:')
    print(FLAGS)  # print all configuration parameters

    root_path = FLAGS.root_path
    train_path = root_path + FLAGS.train_path
    dev_path = root_path + FLAGS.dev_path
    test_path = root_path + FLAGS.test_path
    word_vec_path = root_path + FLAGS.word_vec_path
    model_dir = root_path + FLAGS.model_dir

    if tf.gfile.Exists(model_dir + '/mnist_with_summaries'):
        print("delete summaries")
        tf.gfile.DeleteRecursively(model_dir + '/mnist_with_summaries')

    if not os.path.exists(model_dir):
        os.makedirs(model_dir)

    path_prefix = model_dir + "/SentenceMatch.{}".format(FLAGS.suffix)

    namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")  # save the configuration

    best_path = path_prefix + '.best.model'
    label_path = path_prefix + ".label_vocab"
    has_pre_trained_model = False
    ckpt = tf.train.get_checkpoint_state(model_dir)
    if ckpt and ckpt.model_checkpoint_path:
        print("-------has_pre_trained_model--------")
        print(ckpt.model_checkpoint_path)
        has_pre_trained_model = True

    ############# build vocabs#################
    print('Collect words, chars and labels ...')
    (all_words, all_labels) = collect_vocabs(train_path)
    print('Number of words: {}'.format(len(all_words)))
    print('Number of labels: {}'.format(len(all_labels)))

    word_vocab = Vocab(pattern='word')  # build the word vocabulary object
    word_vocab.patternWord(word_vec_path, model_dir)
    label_vocab = Vocab(pattern="label")
    label_vocab.patternLabel(all_labels, label_path)

    print('word_vocab shape is {}'.format(word_vocab.word_vecs.shape))
    print('tag_vocab shape is {}'.format(label_vocab.word_vecs.shape))
    num_classes = len(all_labels)

    if FLAGS.wo_char: char_vocab = None
    #####  Build SentenceMatchDataStream  ################
    print('Build SentenceMatchDataStream ... ')
    trainDataStream = SentenceMatchDataStream(
        train_path,
        word_vocab=word_vocab,
        label_vocab=label_vocab,
        batch_size=FLAGS.batch_size,
        isShuffle=True,
        isLoop=True,
        isSort=False,
        max_sent_length=FLAGS.max_sent_length)

    devDataStream = SentenceMatchDataStream(
        dev_path,
        word_vocab=word_vocab,
        label_vocab=label_vocab,
        batch_size=FLAGS.batch_size,
        isShuffle=False,
        isLoop=True,
        isSort=False,
        max_sent_length=FLAGS.max_sent_length)

    testDataStream = SentenceMatchDataStream(
        test_path,
        word_vocab=word_vocab,
        label_vocab=label_vocab,
        batch_size=FLAGS.batch_size,
        isShuffle=False,
        isLoop=True,
        isSort=False,
        max_sent_length=FLAGS.max_sent_length)

    print('Number of instances in trainDataStream: {}'.format(
        trainDataStream.get_num_instance()))
    print('Number of instances in devDataStream: {}'.format(
        devDataStream.get_num_instance()))
    print('Number of instances in testDataStream: {}'.format(
        testDataStream.get_num_instance()))
    print('Number of batches in trainDataStream: {}'.format(
        trainDataStream.get_num_batch()))
    print('Number of batches in devDataStream: {}'.format(
        devDataStream.get_num_batch()))
    print('Number of batches in testDataStream: {}'.format(
        testDataStream.get_num_batch()))

    sys.stdout.flush()

    best_accuracy = 0.0
    init_scale = 0.01
    g_2 = tf.Graph()
    with g_2.as_default():
        initializer = tf.random_uniform_initializer(-init_scale, init_scale)
        with tf.variable_scope("Model", reuse=None, initializer=initializer):
            train_graph = SentenceMatchModelGraph(
                num_classes,
                word_vocab=word_vocab,
                dropout_rate=FLAGS.dropout_rate,
                learning_rate=FLAGS.learning_rate,
                optimize_type=FLAGS.optimize_type,
                lambda_l2=FLAGS.lambda_l2,
                with_word=True,
                context_lstm_dim=FLAGS.context_lstm_dim,
                aggregation_lstm_dim=FLAGS.aggregation_lstm_dim,
                is_training=True,
                MP_dim=FLAGS.MP_dim,
                context_layer_num=FLAGS.context_layer_num,
                aggregation_layer_num=FLAGS.aggregation_layer_num,
                fix_word_vec=FLAGS.fix_word_vec,
                with_filter_layer=FLAGS.with_filter_layer,
                with_highway=FLAGS.with_highway,
                with_match_highway=FLAGS.with_match_highway,
                with_aggregation_highway=FLAGS.with_aggregation_highway,
                highway_layer_num=FLAGS.highway_layer_num,
                with_lex_decomposition=FLAGS.with_lex_decomposition,
                lex_decompsition_dim=FLAGS.lex_decompsition_dim,
                with_left_match=(not FLAGS.wo_left_match),
                with_right_match=(not FLAGS.wo_right_match),
                with_full_match=(not FLAGS.wo_full_match),
                with_maxpool_match=(not FLAGS.wo_maxpool_match),
                with_attentive_match=(not FLAGS.wo_attentive_match),
                with_max_attentive_match=(not FLAGS.wo_max_attentive_match))
            tf.summary.scalar("Training Loss", train_graph.get_loss())

        with tf.variable_scope("Model", reuse=True, initializer=initializer):
            valid_graph = SentenceMatchModelGraph(
                num_classes,
                word_vocab=word_vocab,
                dropout_rate=FLAGS.dropout_rate,
                learning_rate=FLAGS.learning_rate,
                optimize_type=FLAGS.optimize_type,
                lambda_l2=FLAGS.lambda_l2,
                with_word=True,
                context_lstm_dim=FLAGS.context_lstm_dim,
                aggregation_lstm_dim=FLAGS.aggregation_lstm_dim,
                is_training=False,
                MP_dim=FLAGS.MP_dim,
                context_layer_num=FLAGS.context_layer_num,
                aggregation_layer_num=FLAGS.aggregation_layer_num,
                fix_word_vec=FLAGS.fix_word_vec,
                with_filter_layer=FLAGS.with_filter_layer,
                with_highway=FLAGS.with_highway,
                with_match_highway=FLAGS.with_match_highway,
                with_aggregation_highway=FLAGS.with_aggregation_highway,
                highway_layer_num=FLAGS.highway_layer_num,
                with_lex_decomposition=FLAGS.with_lex_decomposition,
                lex_decompsition_dim=FLAGS.lex_decompsition_dim,
                with_left_match=(not FLAGS.wo_left_match),
                with_right_match=(not FLAGS.wo_right_match),
                with_full_match=(not FLAGS.wo_full_match),
                with_maxpool_match=(not FLAGS.wo_maxpool_match),
                with_attentive_match=(not FLAGS.wo_attentive_match),
                with_max_attentive_match=(not FLAGS.wo_max_attentive_match))

        initializer = tf.global_variables_initializer()
        saver = tf.train.Saver()
        # vars_ = {}
        # for var in tf.global_variables():
        #     if "word_embedding" in var.name: continue
        #     vars_[var.name.split(":")[0]] = var
        # saver = tf.train.Saver(vars_)

        sess = tf.Session()

        merged = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter(
            model_dir + '/mnist_with_summaries/train', sess.graph)
        sess.run(initializer)

        if has_pre_trained_model:
            print("Restoring model from " + best_path)
            saver.restore(sess, best_path)
            print("DONE!")

        print('Start the training loop.')
        train_size = trainDataStream.get_num_batch()
        max_steps = train_size * FLAGS.max_epochs
        total_loss = 0.0
        start_time = time.time()
        for step in range(max_steps):
            cur_batch = trainDataStream.nextBatch()
            (label_id_batch, word_idx_1_batch, word_idx_2_batch,
             sent1_length_batch, sent2_length_batch) = cur_batch
            feed_dict = {
                train_graph.get_truth(): label_id_batch,
                train_graph.get_question_lengths(): sent1_length_batch,
                train_graph.get_passage_lengths(): sent2_length_batch,
                train_graph.get_in_question_words(): word_idx_1_batch,
                train_graph.get_in_passage_words(): word_idx_2_batch,
            }

            # in_question_repres,in_ques=sess.run([train_graph.in_question_repres,train_graph.in_ques],feed_dict=feed_dict)
            # print(in_question_repres,in_ques)
            # break

            _, loss_value, summary = sess.run(
                [train_graph.get_train_op(),
                 train_graph.get_loss(), merged],
                feed_dict=feed_dict)
            total_loss += loss_value

            if step % 5000 == 0:
                # train_writer.add_summary(summary, step)
                print("step:", step, "loss:", loss_value)

            if (step + 1) % trainDataStream.get_num_batch() == 0 or (
                    step + 1) == max_steps:
                print()
                duration = time.time() - start_time
                start_time = time.time()
                print('Step %d: loss = %.2f (%.3f sec)' %
                      (step, total_loss, duration))
                total_loss = 0.0

                print('Validation Data Eval:')
                accuracy = evaluate(devDataStream, valid_graph, sess)
                print("Current accuracy is %.2f" % accuracy)
                if accuracy >= best_accuracy:
                    print('Saving model since it\'s the best so far')
                    best_accuracy = accuracy
                    saver.save(sess, best_path)
            sys.stdout.flush()

    print("Best accuracy on dev set is %.2f" % best_accuracy)
Ejemplo n.º 21
0
def main(_):
    log_dir = FLAGS.model_dir
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    path_prefix = log_dir + "/G2S.{}".format(FLAGS.suffix)
    log_file_path = path_prefix + ".log"
    print('Log file path: {}'.format(log_file_path))
    log_file = open(log_file_path, 'wt')
    log_file.write("{}\n".format(FLAGS))
    log_file.flush()

    # save configuration
    namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")

    print('Loading data.')
    FLAGS.num_relations = 2
    trainset = G2S_data_stream.read_bionlp_file(FLAGS.train_path, FLAGS.train_dep_path, FLAGS)
    if FLAGS.dev_gen == 'shuffle':
        random.shuffle(trainset)
    elif FLAGS.dev_gen == 'last':
        trainset.reverse()
    N = int(len(trainset)*FLAGS.dev_percent)
    devset = trainset[:N]
    trainset = trainset[N:]

    print('Number of training samples: {}'.format(len(trainset)))
    print('Number of dev samples: {}'.format(len(devset)))
    print('Number of relations: {}'.format(FLAGS.num_relations))

    word_vocab = None
    char_vocab = None
    POS_vocab = None
    edgelabel_vocab = None
    has_pretrained_model = False
    best_path = path_prefix + ".best.model"
    if os.path.exists(best_path + ".index"):
        has_pretrained_model = True
        print('!!Existing pretrained model. Loading vocabs.')
        word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2')
        print('word_vocab: {}'.format(word_vocab.word_vecs.shape))
        if FLAGS.with_char:
            char_vocab = Vocab(path_prefix + ".char_vocab", fileformat='txt2')
            print('char_vocab: {}'.format(char_vocab.word_vecs.shape))
        if FLAGS.with_POS:
            POS_vocab = Vocab(path_prefix + ".POS_vocab", fileformat='txt2')
            print('POS_vocab: {}'.format(POS_vocab.word_vecs.shape))
        edgelabel_vocab = Vocab(path_prefix + ".edgelabel_vocab", fileformat='txt2')
        print('edgelabel_vocab: {}'.format(edgelabel_vocab.word_vecs.shape))
    else:
        print('Collecting vocabs.')
        all_words = set()
        all_chars = set()
        all_poses = set()
        all_edgelabels = set()
        G2S_data_stream.collect_vocabs(trainset, all_words, all_chars, all_poses, all_edgelabels)
        G2S_data_stream.collect_vocabs(devset, all_words, all_chars, all_poses, all_edgelabels)
        print('Number of words: {}'.format(len(all_words)))
        print('Number of chars: {}'.format(len(all_chars)))
        print('Number of poses: {}'.format(len(all_poses)))
        print('Number of edgelabels: {}'.format(len(all_edgelabels)))

        word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2')
        if FLAGS.with_char:
            char_vocab = Vocab(voc=all_chars, dim=FLAGS.char_dim, fileformat='build')
            char_vocab.dump_to_txt2(path_prefix + ".char_vocab")
        if FLAGS.with_POS:
            POS_vocab = Vocab(voc=all_poses, dim=FLAGS.POS_dim, fileformat='build')
            POS_vocab.dump_to_txt2(path_prefix + ".POS_vocab")
        edgelabel_vocab = Vocab(voc=all_edgelabels, dim=FLAGS.edgelabel_dim, fileformat='build')
        edgelabel_vocab.dump_to_txt2(path_prefix + ".edgelabel_vocab")
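    # From here on the vocabularies are ready: either restored from the
    # *.char_vocab / *.POS_vocab / *.edgelabel_vocab dumps of a previous run,
    # or freshly built from the train/dev data and dumped for future runs.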

    print('word vocab size {}'.format(word_vocab.vocab_size))
    sys.stdout.flush()

    print('Build DataStream ... ')
    trainDataStream = G2S_data_stream.G2SDataStream(FLAGS, trainset, word_vocab, char_vocab, POS_vocab, edgelabel_vocab,
            isShuffle=True, isLoop=True, isSort=True, is_training=True)

    devDataStream = G2S_data_stream.G2SDataStream(FLAGS, devset, word_vocab, char_vocab, POS_vocab, edgelabel_vocab,
            isShuffle=False, isLoop=False, isSort=True)
    print('Number of instances in trainDataStream: {}'.format(trainDataStream.get_num_instance()))
    print('Number of instances in devDataStream: {}'.format(devDataStream.get_num_instance()))
    print('Number of batches in trainDataStream: {}'.format(trainDataStream.get_num_batch()))
    print('Number of batches in devDataStream: {}'.format(devDataStream.get_num_batch()))
    sys.stdout.flush()

    FLAGS.trn_bch_num = trainDataStream.get_num_batch()

    # initialize the best bleu and accu scores for current training session
    best_accu = FLAGS.best_accu if 'best_accu' in FLAGS.__dict__ else 0.0
    if best_accu > 0.0:
        print('With initial dev accuracy {}'.format(best_accu))

    init_scale = 0.01
    with tf.Graph().as_default():
        initializer = tf.random_uniform_initializer(-init_scale, init_scale)
        with tf.name_scope("Train"):
            with tf.variable_scope("Model", reuse=None, initializer=initializer):
                train_graph = ModelGraph(word_vocab, char_vocab, POS_vocab, edgelabel_vocab,
                                         FLAGS, mode='train')

        with tf.name_scope("Valid"):
            with tf.variable_scope("Model", reuse=True, initializer=initializer):
                valid_graph = ModelGraph(word_vocab, char_vocab, POS_vocab, edgelabel_vocab,
                                         FLAGS, mode='evaluate')

        initializer = tf.global_variables_initializer()

        vars_ = {}
        for var in tf.all_variables():
            if FLAGS.fix_word_vec and "word_embedding" in var.name: continue
            if not var.name.startswith("Model"): continue
            vars_[var.name.split(":")[0]] = var
        saver = tf.train.Saver(vars_)
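        # Only "Model/..." variables are checkpointed; the word embedding
        # matrix is skipped when fix_word_vec is set, presumably because it is
        # reloaded from the pretrained vectors rather than from the checkpoint.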

        sess = tf.Session()
        sess.run(initializer)
        if has_pretrained_model:
            print("Restoring model from " + best_path)
            saver.restore(sess, best_path)
            print("DONE!")

            if abs(best_accu) < 1e-5:
                print("Getting ACCU score for the model")
                best_accu = evaluate(sess, valid_graph, devDataStream, FLAGS)['dev_f1']
                FLAGS.best_accu = best_accu
                namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")
                print('ACCU = %.4f' % best_accu)
                log_file.write('ACCU = %.4f\n' % best_accu)

        print('Start the training loop.')
        train_size = trainDataStream.get_num_batch()
        max_steps = train_size * FLAGS.max_epochs
        last_step = 0
        total_loss = 0.0
        start_time = time.time()
        for step in xrange(max_steps):
            cur_batch = trainDataStream.nextBatch()
            _, _, cur_loss, _ = train_graph.execute(sess, cur_batch, FLAGS, is_train=True)
            total_loss += cur_loss

            if step % 100==0:
                print('{} '.format(step), end="")
                sys.stdout.flush()

            # Save a checkpoint and evaluate the model periodically.
            if (step + 1) % trainDataStream.get_num_batch() == 0 or (step + 1) == max_steps:
                print()
                duration = time.time() - start_time
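                # total_loss / (step - last_step) below is roughly the average
                # training loss per step since the previous report.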
                print('Step %d: loss = %.2f (%.3f sec)' % (step, total_loss/(step-last_step), duration))
                log_file.write('Step %d: loss = %.2f (%.3f sec)\n' % (step, total_loss/(step-last_step), duration))
                sys.stdout.flush()
                log_file.flush()
                last_step = step
                total_loss = 0.0

                # Evaluate against the validation set.
                start_time = time.time()
                print('Validation Data Eval:')
                res_dict = evaluate(sess, valid_graph, devDataStream, FLAGS)
                dev_loss = res_dict['dev_loss']
                dev_accu = res_dict['dev_f1']
                dev_precision = res_dict['dev_precision']
                dev_recall = res_dict['dev_recall']
                print('Dev loss = %.4f' % dev_loss)
                log_file.write('Dev loss = %.4f\n' % dev_loss)
                print('Dev F1 = %.4f, P = %.4f, R = %.4f' % (dev_accu, dev_precision, dev_recall))
                log_file.write('Dev F1 = %.4f, P =  %.4f, R = %.4f\n' % (dev_accu, dev_precision, dev_recall))
                log_file.flush()
                if best_accu < dev_accu:
                    print('Saving weights, ACCU {} (prev_best) < {} (cur)'.format(best_accu, dev_accu))
                    saver.save(sess, best_path)
                    best_accu = dev_accu
                    FLAGS.best_accu = dev_accu
                    namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")
                duration = time.time() - start_time
                print('Duration %.3f sec' % (duration))
                sys.stdout.flush()

                log_file.write('Duration %.3f sec\n' % (duration))
                log_file.flush()
                start_time = time.time()

    log_file.close()
Ejemplo n.º 22
0
def main(_):
    print('Configurations:')
    print(FLAGS)

    train_path = FLAGS.train_path
    dev_path = FLAGS.dev_path
    test_path = FLAGS.test_path
    word_vec_path = FLAGS.word_vec_path
    log_dir = FLAGS.model_dir
    if FLAGS.train == "sick":
        train_path = FLAGS.SICK_train_path
        dev_path = FLAGS.SICK_dev_path
    if FLAGS.test == "sick":
        test_path = FLAGS.SICK_test_path
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    
    path_prefix = log_dir + "/SentenceMatch.{}".format(FLAGS.suffix)
    namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")

    # build vocabs
    parser,image_feats = None, None
    if FLAGS.with_dep:
        parser=Parser('snli')
    if FLAGS.with_image:
        image_feats=ImageFeatures()
    word_vocab = Vocab(word_vec_path, fileformat='txt3', parser=parser, beginning=FLAGS.beginning) #fileformat='txt3'
    best_path = path_prefix + '.best.model'
    char_path = path_prefix + ".char_vocab"
    label_path = path_prefix + ".label_vocab"
    POS_path = path_prefix + ".POS_vocab"
    NER_path = path_prefix + ".NER_vocab"
    DEP_path = path_prefix + ".DEP_vocab"
    has_pre_trained_model = False
    POS_vocab = None
    NER_vocab = None
    DEP_vocab = None
    print('has pretrained model: ', os.path.exists(best_path + '.meta'))
    print('best_path: ' + best_path)
    if os.path.exists(best_path + '.meta'):
        has_pre_trained_model = True
        label_vocab = Vocab(label_path, fileformat='txt2')
        char_vocab = Vocab(char_path, fileformat='txt2')
        if FLAGS.with_POS: POS_vocab = Vocab(POS_path, fileformat='txt2')
        if FLAGS.with_NER: NER_vocab = Vocab(NER_path, fileformat='txt2')
    else:
        print('Collect words, chars and labels ...')
        (all_words, all_chars, all_labels, all_POSs, all_NERs) = collect_vocabs(train_path, with_POS=FLAGS.with_POS, with_NER=FLAGS.with_NER)
        print('Number of words: {}'.format(len(all_words)))
        print('Number of labels: {}'.format(len(all_labels)))
        label_vocab = Vocab(fileformat='voc', voc=all_labels,dim=2)
        label_vocab.dump_to_txt2(label_path)

        print('Number of chars: {}'.format(len(all_chars)))
        char_vocab = Vocab(fileformat='voc', voc=all_chars,dim=FLAGS.char_emb_dim, beginning=FLAGS.beginning)
        char_vocab.dump_to_txt2(char_path)
        
        if FLAGS.with_POS:
            print('Number of POSs: {}'.format(len(all_POSs)))
            POS_vocab = Vocab(fileformat='voc', voc=all_POSs,dim=FLAGS.POS_dim)
            POS_vocab.dump_to_txt2(POS_path)
        if FLAGS.with_NER:
            print('Number of NERs: {}'.format(len(all_NERs)))
            NER_vocab = Vocab(fileformat='voc', voc=all_NERs,dim=FLAGS.NER_dim)
            NER_vocab.dump_to_txt2(NER_path)
            

    print('word_vocab shape is {}'.format(word_vocab.word_vecs.shape))
    print('tag_vocab shape is {}'.format(label_vocab.word_vecs.shape))
    num_classes = label_vocab.size()

    print('Build DataStream ... ')
    print('Reading trainDataStream')
    if not FLAGS.decoding_only:
        trainDataStream = DataStream(train_path, word_vocab=word_vocab, char_vocab=char_vocab, 
                                              POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab, 
                                              batch_size=FLAGS.batch_size, isShuffle=True, isLoop=True, isSort=True, 
                                              max_char_per_word=FLAGS.max_char_per_word, max_sent_length=FLAGS.max_sent_length, 
                                              with_dep=FLAGS.with_dep, with_image=FLAGS.with_image, image_feats=image_feats, sick_data=(FLAGS.test == "sick"))
    
        print('Reading devDataStream')
        devDataStream = DataStream(dev_path, word_vocab=word_vocab, char_vocab=char_vocab,
                                              POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab, 
                                              batch_size=FLAGS.batch_size, isShuffle=False, isLoop=True, isSort=True, 
                                              max_char_per_word=FLAGS.max_char_per_word, max_sent_length=FLAGS.max_sent_length, 
                                              with_dep=FLAGS.with_dep, with_image=FLAGS.with_image, image_feats=image_feats, sick_data=(FLAGS.test == "sick"))

    print('Reading testDataStream')
    testDataStream = DataStream(test_path, word_vocab=word_vocab, char_vocab=char_vocab, 
                                              POS_vocab=POS_vocab, NER_vocab=NER_vocab, label_vocab=label_vocab, 
                                              batch_size=FLAGS.batch_size, isShuffle=False, isLoop=True, isSort=True, 
                                                  max_char_per_word=FLAGS.max_char_per_word, max_sent_length=FLAGS.max_sent_length, 
                                                  with_dep=FLAGS.with_dep, with_image=FLAGS.with_image, image_feats=image_feats, sick_data=(FLAGS.test == "sick"))
    print('save cache file')
    #word_vocab.parser.save_cache()
    #image_feats.save_feat()
    if not FLAGS.decoding_only:
        print('Number of instances in trainDataStream: {}'.format(trainDataStream.get_num_instance()))
        print('Number of instances in devDataStream: {}'.format(devDataStream.get_num_instance()))
        print('Number of batches in trainDataStream: {}'.format(trainDataStream.get_num_batch()))
        print('Number of batches in devDataStream: {}'.format(devDataStream.get_num_batch()))
    print('Number of instances in testDataStream: {}'.format(testDataStream.get_num_instance()))
    print('Number of batches in testDataStream: {}'.format(testDataStream.get_num_batch()))
    
    sys.stdout.flush()
    if FLAGS.wo_char: char_vocab = None

    best_accuracy = 0.0
    init_scale = 0.01
    
    if not FLAGS.decoding_only:
        with tf.Graph().as_default():
            initializer = tf.random_uniform_initializer(-init_scale, init_scale)
            with tf.variable_scope("Model", reuse=None, initializer=initializer):
                train_graph = ModelGraph(num_classes, word_vocab=word_vocab, char_vocab=char_vocab,POS_vocab=POS_vocab, NER_vocab=NER_vocab, 
                    dropout_rate=FLAGS.dropout_rate, learning_rate=FLAGS.learning_rate, optimize_type=FLAGS.optimize_type,
                    lambda_l2=FLAGS.lambda_l2, char_lstm_dim=FLAGS.char_lstm_dim, context_lstm_dim=FLAGS.context_lstm_dim, 
                    aggregation_lstm_dim=FLAGS.aggregation_lstm_dim, is_training=True, MP_dim=FLAGS.MP_dim, 
                    context_layer_num=FLAGS.context_layer_num, aggregation_layer_num=FLAGS.aggregation_layer_num, 
                    fix_word_vec=FLAGS.fix_word_vec,with_filter_layer=FLAGS.with_filter_layer, with_highway=FLAGS.with_highway,
                    word_level_MP_dim=FLAGS.word_level_MP_dim,
                    with_match_highway=FLAGS.with_match_highway, with_aggregation_highway=FLAGS.with_aggregation_highway,
                    highway_layer_num=FLAGS.highway_layer_num,with_lex_decomposition=FLAGS.with_lex_decomposition, 
                    lex_decompsition_dim=FLAGS.lex_decompsition_dim,
                    with_left_match=(not FLAGS.wo_left_match), with_right_match=(not FLAGS.wo_right_match),
                    with_full_match=(not FLAGS.wo_full_match), with_maxpool_match=(not FLAGS.wo_maxpool_match), 
                    with_attentive_match=(not FLAGS.wo_attentive_match), with_max_attentive_match=(not FLAGS.wo_max_attentive_match), 
                    with_dep=FLAGS.with_dep, with_image=FLAGS.with_image, image_with_hypothesis_only=FLAGS.image_with_hypothesis_only,
                    with_img_full_match=FLAGS.with_img_full_match, with_img_maxpool_match=FLAGS.with_img_full_match, 
                    with_img_attentive_match=FLAGS.with_img_attentive_match, image_context_layer=FLAGS.image_context_layer, 
                    with_img_max_attentive_match=FLAGS.with_img_max_attentive_match, img_dim=FLAGS.img_dim)
                tf.summary.scalar("Training Loss", train_graph.get_loss()) # Add a scalar summary for the snapshot loss.
        
            with tf.variable_scope("Model", reuse=True, initializer=initializer):
                valid_graph = ModelGraph(num_classes, word_vocab=word_vocab, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab, 
                    dropout_rate=FLAGS.dropout_rate, learning_rate=FLAGS.learning_rate, optimize_type=FLAGS.optimize_type,
                    lambda_l2=FLAGS.lambda_l2, char_lstm_dim=FLAGS.char_lstm_dim, context_lstm_dim=FLAGS.context_lstm_dim, 
                    aggregation_lstm_dim=FLAGS.aggregation_lstm_dim, is_training=False, MP_dim=FLAGS.MP_dim, 
                    context_layer_num=FLAGS.context_layer_num, aggregation_layer_num=FLAGS.aggregation_layer_num, 
                    fix_word_vec=FLAGS.fix_word_vec,with_filter_layer=FLAGS.with_filter_layer, with_highway=FLAGS.with_highway,
                    word_level_MP_dim=FLAGS.word_level_MP_dim,
                    with_match_highway=FLAGS.with_match_highway, with_aggregation_highway=FLAGS.with_aggregation_highway,
                    highway_layer_num=FLAGS.highway_layer_num, with_lex_decomposition=FLAGS.with_lex_decomposition, 
                    lex_decompsition_dim=FLAGS.lex_decompsition_dim,
                    with_left_match=(not FLAGS.wo_left_match), with_right_match=(not FLAGS.wo_right_match),
                    with_full_match=(not FLAGS.wo_full_match), with_maxpool_match=(not FLAGS.wo_maxpool_match), 
                    with_attentive_match=(not FLAGS.wo_attentive_match), with_max_attentive_match=(not FLAGS.wo_max_attentive_match), 
                    with_dep=FLAGS.with_dep, with_image=FLAGS.with_image, image_with_hypothesis_only=FLAGS.image_with_hypothesis_only,
                    with_img_attentive_match=FLAGS.with_img_attentive_match, with_img_full_match=FLAGS.with_img_full_match, 
                    with_img_maxpool_match=FLAGS.with_img_full_match, image_context_layer=FLAGS.image_context_layer, 
                    with_img_max_attentive_match=FLAGS.with_img_max_attentive_match, img_dim=FLAGS.img_dim)

                
            initializer = tf.global_variables_initializer()
            vars_ = {}
            for var in tf.all_variables():
                if "word_embedding" in var.name: continue
                vars_[var.name.split(":")[0]] = var
            saver = tf.train.Saver(vars_)
         
            sess = tf.Session()
            sess.run(initializer) #, feed_dict={valid_graph.emb_init: word_vocab.word_vecs, train_graph.emb_init: word_vocab.word_vecs})
            if has_pre_trained_model:
                print("Restoring model from " + best_path)
                saver.restore(sess, best_path)
                print("DONE!")
                #if best_path.startswith('bimpm_baseline'):
                #best_path = best_path + '_sick' 

            print('Start the training loop.')
            train_size = trainDataStream.get_num_batch()
            max_steps = train_size * FLAGS.max_epochs
            total_loss = 0.0
            start_time = time.time()
            for step in xrange(max_steps):
                # read data
                cur_batch = trainDataStream.nextBatch()
                (label_batch, sent1_batch, sent2_batch, label_id_batch, word_idx_1_batch, word_idx_2_batch, 
                                 char_matrix_idx_1_batch, char_matrix_idx_2_batch, sent1_length_batch, sent2_length_batch, 
                                 sent1_char_length_batch, sent2_char_length_batch,
                                 POS_idx_1_batch, POS_idx_2_batch, NER_idx_1_batch, NER_idx_2_batch, 
                                 dependency1_batch, dependency2_batch, dep_con1_batch, dep_con2_batch, img_feats_batch, img_id_batch) = cur_batch
                feed_dict = {
                         train_graph.get_truth(): label_id_batch, 
                         train_graph.get_question_lengths(): sent1_length_batch, 
                         train_graph.get_passage_lengths(): sent2_length_batch, 
                         train_graph.get_in_question_words(): word_idx_1_batch, 
                         train_graph.get_in_passage_words(): word_idx_2_batch,
                         #train_graph.get_emb_init(): word_vocab.word_vecs,
                         #train_graph.get_in_question_dependency(): dependency1_batch,
                         #train_graph.get_in_passage_dependency(): dependency2_batch,
#                          train_graph.get_question_char_lengths(): sent1_char_length_batch, 
#                          train_graph.get_passage_char_lengths(): sent2_char_length_batch, 
#                          train_graph.get_in_question_chars(): char_matrix_idx_1_batch, 
#                          train_graph.get_in_passage_chars(): char_matrix_idx_2_batch, 
                         }
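                # The blocks below attach optional inputs (dependency arcs,
                # image features, character ids, POS and NER tags) only when
                # the corresponding flag or vocabulary is enabled, so a single
                # training loop serves every model variant.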
                if FLAGS.with_dep:
                    feed_dict[train_graph.get_in_question_dependency()] = dependency1_batch
                    feed_dict[train_graph.get_in_passage_dependency()] = dependency2_batch
                    feed_dict[train_graph.get_in_question_dep_con()] = dep_con1_batch
                    feed_dict[train_graph.get_in_passage_dep_con()] = dep_con2_batch

                if FLAGS.with_image:
                    feed_dict[train_graph.get_image_feats()] = img_feats_batch

                if char_vocab is not None:
                    feed_dict[train_graph.get_question_char_lengths()] = sent1_char_length_batch
                    feed_dict[train_graph.get_passage_char_lengths()] = sent2_char_length_batch
                    feed_dict[train_graph.get_in_question_chars()] = char_matrix_idx_1_batch
                    feed_dict[train_graph.get_in_passage_chars()] = char_matrix_idx_2_batch

                if POS_vocab is not None:
                    feed_dict[train_graph.get_in_question_poss()] = POS_idx_1_batch
                    feed_dict[train_graph.get_in_passage_poss()] = POS_idx_2_batch

                if NER_vocab is not None:
                    feed_dict[train_graph.get_in_question_ners()] = NER_idx_1_batch
                    feed_dict[train_graph.get_in_passage_ners()] = NER_idx_2_batch

                _, loss_value = sess.run([train_graph.get_train_op(), train_graph.get_loss()], feed_dict=feed_dict)
                total_loss += loss_value
            
                if step % 100==0: 
                    print('{} '.format(step), end="")
                    sys.stdout.flush()

                # Save a checkpoint and evaluate the model periodically.
                if (step + 1) % trainDataStream.get_num_batch() == 0 or (step + 1) == max_steps:
                    print()
                    # Print status to stdout.
                    duration = time.time() - start_time
                    start_time = time.time()
                    print('Step %d: loss = %.2f (%.3f sec)' % (step, total_loss, duration))
                    total_loss = 0.0

                    # Evaluate against the validation set.
                    print('Validation Data Eval:')
                    accuracy = evaluate(devDataStream, valid_graph, sess,char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab, word_vocab=word_vocab)
                    print("Current accuracy on dev is %.2f" % accuracy)
                
                    #accuracy_train = evaluate(trainDataStream, valid_graph, sess,char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab)
                    #print("Current accuracy on train is %.2f" % accuracy_train)
                    if accuracy>best_accuracy:
                        best_accuracy = accuracy
                        saver.save(sess, best_path)
        print("Best accuracy on dev set is %.2f" % best_accuracy)
    
    # decoding
    print('Decoding on the test set:')
    init_scale = 0.01
    with tf.Graph().as_default():
        initializer = tf.random_uniform_initializer(-init_scale, init_scale)
        with tf.variable_scope("Model", reuse=False, initializer=initializer):
            valid_graph = ModelGraph(num_classes, word_vocab=word_vocab, char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab, 
                 dropout_rate=FLAGS.dropout_rate, learning_rate=FLAGS.learning_rate, optimize_type=FLAGS.optimize_type,
                 lambda_l2=FLAGS.lambda_l2, char_lstm_dim=FLAGS.char_lstm_dim, context_lstm_dim=FLAGS.context_lstm_dim, 
                 aggregation_lstm_dim=FLAGS.aggregation_lstm_dim, is_training=False, MP_dim=FLAGS.MP_dim, 
                 context_layer_num=FLAGS.context_layer_num, aggregation_layer_num=FLAGS.aggregation_layer_num, 
                 fix_word_vec=FLAGS.fix_word_vec,with_filter_layer=FLAGS.with_filter_layer, with_highway=FLAGS.with_highway,
                 word_level_MP_dim=FLAGS.word_level_MP_dim,
                 with_match_highway=FLAGS.with_match_highway, with_aggregation_highway=FLAGS.with_aggregation_highway,
                 highway_layer_num=FLAGS.highway_layer_num, with_lex_decomposition=FLAGS.with_lex_decomposition, 
                 lex_decompsition_dim=FLAGS.lex_decompsition_dim,
                 with_left_match=(not FLAGS.wo_left_match), with_right_match=(not FLAGS.wo_right_match),
                 with_full_match=(not FLAGS.wo_full_match), with_maxpool_match=(not FLAGS.wo_maxpool_match), 
                 with_attentive_match=(not FLAGS.wo_attentive_match), with_max_attentive_match=(not FLAGS.wo_max_attentive_match), 
                 with_dep=FLAGS.with_dep, with_image=FLAGS.with_image, image_with_hypothesis_only=FLAGS.image_with_hypothesis_only,
                 with_img_attentive_match=FLAGS.with_img_attentive_match, with_img_full_match=FLAGS.with_img_full_match, 
                 with_img_maxpool_match=FLAGS.with_img_full_match, image_context_layer=FLAGS.image_context_layer, 
                 with_img_max_attentive_match=FLAGS.with_img_max_attentive_match, img_dim=FLAGS.img_dim)
        vars_ = {}
        for var in tf.all_variables():
            if "word_embedding" in var.name: continue
            if not var.name.startswith("Model"): continue
            vars_[var.name.split(":")[0]] = var
        saver = tf.train.Saver(vars_)
                
        sess = tf.Session()
        sess.run(tf.global_variables_initializer())#, feed_dict={valid_graph.emb_init: word_vocab.word_vecs})
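        # Restore the best checkpoint produced by training (or a previously
        # saved model when FLAGS.decoding_only is set) before decoding.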
        step = 0
        saver.restore(sess, best_path)

        accuracy = evaluate(testDataStream, valid_graph, sess, outpath=FLAGS.suffix+ FLAGS.train + FLAGS.test + ".result",char_vocab=char_vocab,label_vocab=label_vocab, word_vocab=word_vocab)
        print("Accuracy for test set is %.2f" % accuracy)
        accuracy_train = evaluate(trainDataStream, valid_graph, sess,char_vocab=char_vocab,word_vocab=word_vocab)
        print("Accuracy for train set is %.2f" % accuracy_train)
Ejemplo n.º 23
0
                                             self.options.max_src_len)
        self.sent_inp = padding_utils.pad_2d_vals(ori_batch.sent_inp,
                                                  len(ori_batch.sent_inp),
                                                  self.options.max_answer_len)
        self.sent_out = padding_utils.pad_2d_vals(ori_batch.sent_out,
                                                  len(ori_batch.sent_out),
                                                  self.options.max_answer_len)


if __name__ == "__main__":
    FLAGS = namespace_utils.load_namespace('../config.json')
    print('Collecting vocab')
    allEdgelabels = set([line.strip().split()[0] \
            for line in open('../data/edgelabel_vocab.en', 'rU')])
    edgelabel_vocab = Vocab(voc=allEdgelabels,
                            dim=FLAGS.edgelabel_dim,
                            fileformat='build')
    word_vocab_enc = Vocab('../data/vectors.en.st', fileformat='txt2')
    word_vocab_dec = Vocab('../data/vectors.de.st', fileformat='txt2')
    print('Loading trainset')
    trainset, _, _, _, _ = read_amr_file('../data/newstest2013.tok.json',
                                         FLAGS, word_vocab_enc, word_vocab_dec,
                                         None, edgelabel_vocab)
    print('Build DataStream ... ')
    trainDataStream = G2SDataStream(trainset,
                                    word_vocab_enc,
                                    word_vocab_dec,
                                    None,
                                    edgelabel_vocab,
                                    options=FLAGS,
                                    isShuffle=True,
Ejemplo n.º 24
0
    model_prefix = args.model_prefix
    in_path = args.in_path
    out_path = args.out_path
    mode = args.mode

    print("CUDA_VISIBLE_DEVICES " + os.environ['CUDA_VISIBLE_DEVICES'])

    # load the configuration file
    print('Loading configurations from ' + model_prefix + ".config.json")
    FLAGS = namespace_utils.load_namespace(model_prefix + ".config.json")
    FLAGS = G2S_trainer.enrich_options(FLAGS)

    # load vocabs
    print('Loading vocabs.')
    word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2')
    print('word_vocab: {}'.format(word_vocab.word_vecs.shape))
    edgelabel_vocab = Vocab(model_prefix + ".edgelabel_vocab",
                            fileformat='txt2')
    print('edgelabel_vocab: {}'.format(edgelabel_vocab.word_vecs.shape))
    char_vocab = None
    if FLAGS.with_char:
        char_vocab = Vocab(model_prefix + ".char_vocab", fileformat='txt2')
        print('char_vocab: {}'.format(char_vocab.word_vecs.shape))

    print('Loading test set from {}.'.format(in_path))
    testset, _, _, _, _ = G2S_data_stream.read_amr_file(in_path)
    print('Number of samples: {}'.format(len(testset)))

    print('Build DataStream ... ')
    batch_size = -1
Ejemplo n.º 25
0
    model_prefix = args.model_prefix
    in_path = args.in_path
    cache_size = args.cache_size
    use_dep = args.decode

    print("CUDA_VISIBLE_DEVICES " + os.environ['CUDA_VISIBLE_DEVICES'])

    # load the configuration file
    print('Loading configurations from ' + model_prefix + ".config.json")
    FLAGS = namespace_utils.load_namespace(model_prefix + ".config.json")
    FLAGS = NP2P_trainer.enrich_options(FLAGS)

    # load vocabs
    print('Loading vocabs.')
    word_vocab = char_vocab = POS_vocab = NER_vocab = None
    word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2')
    print('word_vocab: {}'.format(word_vocab.word_vecs.shape))
    if FLAGS.with_char:
        char_vocab = Vocab(model_prefix + ".char_vocab", fileformat='txt2')
        print('char_vocab: {}'.format(char_vocab.word_vecs.shape))
    if FLAGS.with_POS:
        POS_vocab = Vocab(model_prefix + ".POS_vocab", fileformat='txt2')
        print('POS_vocab: {}'.format(POS_vocab.word_vecs.shape))
    action_vocab = Vocab(model_prefix + ".action_vocab", fileformat='txt2')
    print('action_vocab: {}'.format(action_vocab.word_vecs.shape))
    feat_vocab = Vocab(model_prefix + ".feat_vocab", fileformat='txt2')
    print('feat_vocab: {}'.format(feat_vocab.word_vecs.shape))

    print('Loading test set.')
    if use_dep:
        testset = NP2P_data_stream.read_Testset(in_path, ulfdep=args.ulf)
Ejemplo n.º 26
0
    config_FLAGS.__dict__["in_format"] = 'tsv'
    word_vec_path = config_FLAGS.word_vec_path
    log_dir = config_FLAGS.model_dir
    path_prefix = os.path.join(log_dir,
                               "SentenceMatch.{}".format(config_FLAGS.suffix))

    ent_word_vocab = EntVocab(word_vec_path, fileformat='txt3')
    print("word_vocab shape is {}".format(ent_word_vocab.word_vecs.shape))

    best_path = path_prefix + ".best.model"
    label_path = path_prefix + ".label_vocab"
    print("best_path: {}".format(best_path))

    if os.path.exists(best_path + ".index"):
        print("Loading label vocab")
        label_vocab = EntVocab(label_path, fileformat='txt2')
    else:
        raise Exception("no pretrained model")

    num_classes = label_vocab.size()
    print("Number of labels: {}".format(num_classes))

    global_step = tf.train.get_global_step()

    # define entailment model
    config_FLAGS = namespace_utils.load_namespace(config_path)
    entailment_model = SentenceMatchModelGraph(3,
                                               word_vocab=ent_word_vocab,
                                               is_training=True,
                                               options=config_FLAGS,
                                               global_step=global_step,
Ejemplo n.º 27
0
def main(_):
    print('Configurations:')
    print(FLAGS)

    train_path = FLAGS.train_path
    dev_path = FLAGS.dev_path
    test_path = FLAGS.test_path
    word_vec_path = FLAGS.word_vec_path
    log_dir = FLAGS.model_dir
    tolower = FLAGS.use_lower_letter
    FLAGS.rl_matches = json.loads(FLAGS.rl_matches)

    # if not os.path.exists(log_dir):
    #     os.makedirs(log_dir)

    path_prefix = log_dir + "/TriMatch.{}".format(FLAGS.suffix)

    namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")

    # build vocabs
    word_vocab = Vocab(word_vec_path, fileformat='txt3', tolower=tolower)
    best_path = path_prefix + '.best.model'
    char_path = path_prefix + ".char_vocab"
    label_path = path_prefix + ".label_vocab"
    POS_path = path_prefix + ".POS_vocab"
    NER_path = path_prefix + ".NER_vocab"
    has_pre_trained_model = False
    POS_vocab = None
    NER_vocab = None

    print('best path:', best_path)
    if os.path.exists(best_path +
                      '.data-00000-of-00001') and not (FLAGS.create_new_model):
        print('Using pretrained model')
        has_pre_trained_model = True
        label_vocab = Vocab(label_path, fileformat='txt2', tolower=tolower)
        char_vocab = Vocab(char_path, fileformat='txt2', tolower=tolower)
        if FLAGS.with_POS:
            POS_vocab = Vocab(POS_path, fileformat='txt2', tolower=tolower)
        if FLAGS.with_NER:
            NER_vocab = Vocab(NER_path, fileformat='txt2', tolower=tolower)
    else:
        print('Creating new model')
        print('Collect words, chars and labels ...')
        (all_words, all_chars, all_labels, all_POSs,
         all_NERs) = collect_vocabs(train_path,
                                    with_POS=FLAGS.with_POS,
                                    with_NER=FLAGS.with_NER,
                                    tolower=tolower)
        if FLAGS.use_options:
            all_labels = ['0', '1']
        print('Number of words: {}'.format(len(all_words)))
        print('Number of labels: {}'.format(len(all_labels)))
        # for word in all_labels:
        #     print('label',word)
        # input('check')

        label_vocab = Vocab(fileformat='voc',
                            voc=all_labels,
                            dim=2,
                            tolower=tolower)
        label_vocab.dump_to_txt2(label_path)

        print('Number of chars: {}'.format(len(all_chars)))
        char_vocab = Vocab(fileformat='voc',
                           voc=all_chars,
                           dim=FLAGS.char_emb_dim,
                           tolower=tolower)
        char_vocab.dump_to_txt2(char_path)

        if FLAGS.with_POS:
            print('Number of POSs: {}'.format(len(all_POSs)))
            POS_vocab = Vocab(fileformat='voc',
                              voc=all_POSs,
                              dim=FLAGS.POS_dim,
                              tolower=tolower)
            POS_vocab.dump_to_txt2(POS_path)
        if FLAGS.with_NER:
            print('Number of NERs: {}'.format(len(all_NERs)))
            NER_vocab = Vocab(fileformat='voc',
                              voc=all_NERs,
                              dim=FLAGS.NER_dim,
                              tolower=tolower)
            NER_vocab.dump_to_txt2(NER_path)

    print('all_labels:', label_vocab)
    print('has pretrained model:', has_pre_trained_model)
    # for word in word_vocab.word_vecs:

    print('word_vocab shape is {}'.format(word_vocab.word_vecs.shape))
    print('tag_vocab shape is {}'.format(label_vocab.word_vecs.shape))
    num_classes = label_vocab.size()

    print('Build TriMatchDataStream ... ')

    gen_concat_mat = False
    gen_split_mat = False
    if FLAGS.matching_option == 7:
        gen_concat_mat = True
        if FLAGS.concat_context:
            gen_split_mat = True
    trainDataStream = TriMatchDataStream(
        train_path,
        word_vocab=word_vocab,
        char_vocab=char_vocab,
        POS_vocab=POS_vocab,
        NER_vocab=NER_vocab,
        label_vocab=label_vocab,
        batch_size=FLAGS.batch_size,
        isShuffle=True,
        isLoop=True,
        isSort=(not FLAGS.wo_sort_instance_based_on_length),
        max_char_per_word=FLAGS.max_char_per_word,
        max_sent_length=FLAGS.max_sent_length,
        max_hyp_length=FLAGS.max_hyp_length,
        max_choice_length=FLAGS.max_choice_length,
        tolower=tolower,
        gen_concat_mat=gen_concat_mat,
        gen_split_mat=gen_split_mat)

    devDataStream = TriMatchDataStream(
        dev_path,
        word_vocab=word_vocab,
        char_vocab=char_vocab,
        POS_vocab=POS_vocab,
        NER_vocab=NER_vocab,
        label_vocab=label_vocab,
        batch_size=FLAGS.batch_size,
        isShuffle=False,
        isLoop=True,
        isSort=(not FLAGS.wo_sort_instance_based_on_length),
        max_char_per_word=FLAGS.max_char_per_word,
        max_sent_length=FLAGS.max_sent_length,
        max_hyp_length=FLAGS.max_hyp_length,
        max_choice_length=FLAGS.max_choice_length,
        tolower=tolower,
        gen_concat_mat=gen_concat_mat,
        gen_split_mat=gen_split_mat)

    testDataStream = TriMatchDataStream(
        test_path,
        word_vocab=word_vocab,
        char_vocab=char_vocab,
        POS_vocab=POS_vocab,
        NER_vocab=NER_vocab,
        label_vocab=label_vocab,
        batch_size=FLAGS.batch_size,
        isShuffle=False,
        isLoop=True,
        isSort=(not FLAGS.wo_sort_instance_based_on_length),
        max_char_per_word=FLAGS.max_char_per_word,
        max_sent_length=FLAGS.max_sent_length,
        max_hyp_length=FLAGS.max_hyp_length,
        max_choice_length=FLAGS.max_choice_length,
        tolower=tolower,
        gen_concat_mat=gen_concat_mat,
        gen_split_mat=gen_split_mat)
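    # The three streams share the same vocabularies: the training stream is
    # shuffled between epochs, while the dev and test streams keep a fixed
    # order so evaluation results are comparable across checkpoints.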

    print('Number of instances in trainDataStream: {}'.format(
        trainDataStream.get_num_instance()))
    print('Number of instances in devDataStream: {}'.format(
        devDataStream.get_num_instance()))
    print('Number of instances in testDataStream: {}'.format(
        testDataStream.get_num_instance()))
    print('Number of batches in trainDataStream: {}'.format(
        trainDataStream.get_num_batch()))
    print('Number of batches in devDataStream: {}'.format(
        devDataStream.get_num_batch()))
    print('Number of batches in testDataStream: {}'.format(
        testDataStream.get_num_batch()))

    sys.stdout.flush()
    if FLAGS.wo_char: char_vocab = None

    best_accuracy = 0.0
    init_scale = 0.01
    with tf.Graph().as_default():
        initializer = tf.random_uniform_initializer(-init_scale, init_scale)
        #         with tf.name_scope("Train"):
        with tf.variable_scope("Model", reuse=None, initializer=initializer):
            train_graph = TriMatchModelGraph(
                num_classes,
                word_vocab=word_vocab,
                char_vocab=char_vocab,
                POS_vocab=POS_vocab,
                NER_vocab=NER_vocab,
                dropout_rate=FLAGS.dropout_rate,
                learning_rate=FLAGS.learning_rate,
                optimize_type=FLAGS.optimize_type,
                lambda_l2=FLAGS.lambda_l2,
                char_lstm_dim=FLAGS.char_lstm_dim,
                context_lstm_dim=FLAGS.context_lstm_dim,
                aggregation_lstm_dim=FLAGS.aggregation_lstm_dim,
                is_training=True,
                MP_dim=FLAGS.MP_dim,
                context_layer_num=FLAGS.context_layer_num,
                aggregation_layer_num=FLAGS.aggregation_layer_num,
                fix_word_vec=FLAGS.fix_word_vec,
                with_highway=FLAGS.with_highway,
                word_level_MP_dim=FLAGS.word_level_MP_dim,
                with_match_highway=FLAGS.with_match_highway,
                with_aggregation_highway=FLAGS.with_aggregation_highway,
                highway_layer_num=FLAGS.highway_layer_num,
                match_to_question=FLAGS.match_to_question,
                match_to_passage=FLAGS.match_to_passage,
                match_to_choice=FLAGS.match_to_choice,
                with_full_match=(not FLAGS.wo_full_match),
                with_maxpool_match=(not FLAGS.wo_maxpool_match),
                with_attentive_match=(not FLAGS.wo_attentive_match),
                with_max_attentive_match=(not FLAGS.wo_max_attentive_match),
                use_options=FLAGS.use_options,
                num_options=num_options,
                with_no_match=FLAGS.with_no_match,
                verbose=FLAGS.verbose,
                matching_option=FLAGS.matching_option,
                concat_context=FLAGS.concat_context,
                tied_aggre=FLAGS.tied_aggre,
                rl_training_method=FLAGS.rl_training_method,
                rl_matches=FLAGS.rl_matches)

            tf.summary.scalar("Training Loss", train_graph.get_loss()
                              )  # Add a scalar summary for the snapshot loss.
        if FLAGS.verbose:
            valid_graph = train_graph
        else:
            #         with tf.name_scope("Valid"):
            with tf.variable_scope("Model",
                                   reuse=True,
                                   initializer=initializer):
                valid_graph = TriMatchModelGraph(
                    num_classes,
                    word_vocab=word_vocab,
                    char_vocab=char_vocab,
                    POS_vocab=POS_vocab,
                    NER_vocab=NER_vocab,
                    dropout_rate=FLAGS.dropout_rate,
                    learning_rate=FLAGS.learning_rate,
                    optimize_type=FLAGS.optimize_type,
                    lambda_l2=FLAGS.lambda_l2,
                    char_lstm_dim=FLAGS.char_lstm_dim,
                    context_lstm_dim=FLAGS.context_lstm_dim,
                    aggregation_lstm_dim=FLAGS.aggregation_lstm_dim,
                    is_training=False,
                    MP_dim=FLAGS.MP_dim,
                    context_layer_num=FLAGS.context_layer_num,
                    aggregation_layer_num=FLAGS.aggregation_layer_num,
                    fix_word_vec=FLAGS.fix_word_vec,
                    with_highway=FLAGS.with_highway,
                    word_level_MP_dim=FLAGS.word_level_MP_dim,
                    with_match_highway=FLAGS.with_match_highway,
                    with_aggregation_highway=FLAGS.with_aggregation_highway,
                    highway_layer_num=FLAGS.highway_layer_num,
                    match_to_question=FLAGS.match_to_question,
                    match_to_passage=FLAGS.match_to_passage,
                    match_to_choice=FLAGS.match_to_choice,
                    with_full_match=(not FLAGS.wo_full_match),
                    with_maxpool_match=(not FLAGS.wo_maxpool_match),
                    with_attentive_match=(not FLAGS.wo_attentive_match),
                    with_max_attentive_match=(
                        not FLAGS.wo_max_attentive_match),
                    use_options=FLAGS.use_options,
                    num_options=num_options,
                    with_no_match=FLAGS.with_no_match,
                    matching_option=FLAGS.matching_option,
                    concat_context=FLAGS.concat_context,
                    tied_aggre=FLAGS.tied_aggre,
                    rl_training_method=FLAGS.rl_training_method,
                    rl_matches=FLAGS.rl_matches)
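        # In verbose mode the training graph doubles as the validation graph;
        # otherwise a separate evaluation graph is built under the same
        # variable scope with reuse=True, so it shares every weight with the
        # training graph instead of keeping a second copy.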

        initializer = tf.global_variables_initializer()
        vars_ = {}
        for var in tf.global_variables():
            # print(var.name,var.get_shape().as_list())
            if "word_embedding" in var.name: continue
            #             if not var.name.startswith("Model"): continue
            vars_[var.name.split(":")[0]] = var
        saver = tf.train.Saver(vars_)
        # input('check')

        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        sess = tf.Session(config=config)
        sess.run(initializer)
        if has_pre_trained_model:
            print("Restoring model from " + best_path)
            saver.restore(sess, best_path)
            print("DONE!")

        print('Start the training loop.')
        train_size = trainDataStream.get_num_batch()
        max_steps = train_size * FLAGS.max_epochs
        total_loss = 0.0
        start_time = time.time()
        sub_loss_counter = 0.0
        for step in range(max_steps):
            # read data
            cur_batch = trainDataStream.nextBatch()
            (label_batch, sent1_batch, sent2_batch, sent3_batch,
             label_id_batch, word_idx_1_batch, word_idx_2_batch,
             word_idx_3_batch, char_matrix_idx_1_batch,
             char_matrix_idx_2_batch, char_matrix_idx_3_batch,
             sent1_length_batch, sent2_length_batch, sent3_length_batch,
             sent1_char_length_batch, sent2_char_length_batch,
             sent3_char_length_batch, POS_idx_1_batch, POS_idx_2_batch,
             NER_idx_1_batch, NER_idx_2_batch, concat_mat_batch,
             split_mat_batch_q, split_mat_batch_c) = cur_batch

            # print(label_id_batch)
            if FLAGS.verbose:
                print(label_id_batch)
                print(sent1_length_batch)
                print(sent2_length_batch)
                print(sent3_length_batch)
                # print(word_idx_1_batch)
                # print(word_idx_2_batch)
                # print(word_idx_3_batch)
                # print(sent1_batch)
                # print(sent2_batch)
                # print(sent3_batch)
                print(concat_mat_batch)
                input('check')
            feed_dict = {
                train_graph.get_truth(): label_id_batch,
                train_graph.get_passage_lengths(): sent1_length_batch,
                train_graph.get_question_lengths(): sent2_length_batch,
                train_graph.get_choice_lengths(): sent3_length_batch,
                train_graph.get_in_passage_words(): word_idx_1_batch,
                train_graph.get_in_question_words(): word_idx_2_batch,
                train_graph.get_in_choice_words(): word_idx_3_batch,
                #                          train_graph.get_question_char_lengths(): sent1_char_length_batch,
                #                          train_graph.get_passage_char_lengths(): sent2_char_length_batch,
                #                          train_graph.get_in_question_chars(): char_matrix_idx_1_batch,
                #                          train_graph.get_in_passage_chars(): char_matrix_idx_2_batch,
            }
            if char_vocab is not None:
                feed_dict[train_graph.get_passage_char_lengths(
                )] = sent1_char_length_batch
                feed_dict[train_graph.get_question_char_lengths(
                )] = sent2_char_length_batch
                feed_dict[train_graph.get_choice_char_lengths(
                )] = sent3_char_length_batch
                feed_dict[train_graph.get_in_passage_chars(
                )] = char_matrix_idx_1_batch
                feed_dict[train_graph.get_in_question_chars(
                )] = char_matrix_idx_2_batch
                feed_dict[train_graph.get_in_choice_chars(
                )] = char_matrix_idx_3_batch

            if POS_vocab is not None:
                feed_dict[train_graph.get_in_passage_poss()] = POS_idx_1_batch
                feed_dict[train_graph.get_in_question_poss()] = POS_idx_2_batch

            if NER_vocab is not None:
                feed_dict[train_graph.get_in_passage_ners()] = NER_idx_1_batch
                feed_dict[train_graph.get_in_question_ners()] = NER_idx_2_batch
            if concat_mat_batch is not None:
                feed_dict[train_graph.concat_idx_mat] = concat_mat_batch
            if split_mat_batch_q is not None:
                feed_dict[train_graph.split_idx_mat_q] = split_mat_batch_q
                feed_dict[train_graph.split_idx_mat_c] = split_mat_batch_c

            if FLAGS.verbose:
                return_list = sess.run([
                    train_graph.get_train_op(),
                    train_graph.get_loss(),
                    train_graph.get_predictions(),
                    train_graph.get_prob(), train_graph.all_probs,
                    train_graph.correct
                ] + train_graph.matching_vectors,
                                       feed_dict=feed_dict)
                print(len(return_list))
                _, loss_value, pred, prob, all_probs, correct = return_list[
                    0:6]
                print('pred=', pred)
                print('prob=', prob)
                print('logits=', all_probs)
                print('correct=', correct)
                for val in return_list[6:]:
                    if isinstance(val, list):
                        print('list len ', len(val))
                        for objj in val:
                            print('this shape=', objj.shape)
                    else:
                        print('this shape=', val.shape)
                    # print(val)
                input('check')
            else:
                _, loss_value = sess.run(
                    [train_graph.get_train_op(),
                     train_graph.get_loss()],
                    feed_dict=feed_dict)
            total_loss += loss_value
            sub_loss_counter += loss_value

            if step % int(FLAGS.display_every) == 0:
                print('{},{} '.format(step, sub_loss_counter), end="")
                sys.stdout.flush()
                sub_loss_counter = 0.0

            # Save a checkpoint and evaluate the model periodically.
            if (step + 1) % trainDataStream.get_num_batch() == 0 or (
                    step + 1) == max_steps:
                print()
                # Print status to stdout.
                duration = time.time() - start_time
                start_time = time.time()
                print('Step %d: loss = %.2f (%.3f sec)' %
                      (step, total_loss, duration))
                total_loss = 0.0

                # Evaluate against the validation set.
                print('Validation Data Eval:')
                if FLAGS.predict_val:
                    outpath = path_prefix + '.iter%d' % (step) + '.probs'
                else:
                    outpath = None
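                # With predict_val set, per-example probabilities for this
                # checkpoint are written to "<prefix>.iter<step>.probs";
                # otherwise evaluation only reports accuracy.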
                accuracy = evaluate(devDataStream,
                                    valid_graph,
                                    sess,
                                    char_vocab=char_vocab,
                                    POS_vocab=POS_vocab,
                                    NER_vocab=NER_vocab,
                                    use_options=FLAGS.use_options,
                                    outpath=outpath,
                                    mode='prob')
                print("Current accuracy on dev set is %.2f" % accuracy)
                if accuracy >= best_accuracy:
                    best_accuracy = accuracy
                    saver.save(sess, best_path)
                    print('saving the current model.')
                accuracy = evaluate(testDataStream,
                                    valid_graph,
                                    sess,
                                    char_vocab=char_vocab,
                                    POS_vocab=POS_vocab,
                                    NER_vocab=NER_vocab,
                                    use_options=FLAGS.use_options,
                                    outpath=outpath,
                                    mode='prob')
                print("Current accuracy on test set is %.2f" % accuracy)

    print("Best accuracy on dev set is %.2f" % best_accuracy)
    # decoding
    print('Decoding on the test set:')
    init_scale = 0.01
    with tf.Graph().as_default():
        initializer = tf.random_uniform_initializer(-init_scale, init_scale)
        with tf.variable_scope("Model", reuse=False, initializer=initializer):
            valid_graph = TriMatchModelGraph(
                num_classes,
                word_vocab=word_vocab,
                char_vocab=char_vocab,
                POS_vocab=POS_vocab,
                NER_vocab=NER_vocab,
                dropout_rate=FLAGS.dropout_rate,
                learning_rate=FLAGS.learning_rate,
                optimize_type=FLAGS.optimize_type,
                lambda_l2=FLAGS.lambda_l2,
                char_lstm_dim=FLAGS.char_lstm_dim,
                context_lstm_dim=FLAGS.context_lstm_dim,
                aggregation_lstm_dim=FLAGS.aggregation_lstm_dim,
                is_training=False,
                MP_dim=FLAGS.MP_dim,
                context_layer_num=FLAGS.context_layer_num,
                aggregation_layer_num=FLAGS.aggregation_layer_num,
                fix_word_vec=FLAGS.fix_word_vec,
                with_highway=FLAGS.with_highway,
                word_level_MP_dim=FLAGS.word_level_MP_dim,
                with_match_highway=FLAGS.with_match_highway,
                with_aggregation_highway=FLAGS.with_aggregation_highway,
                highway_layer_num=FLAGS.highway_layer_num,
                match_to_question=FLAGS.match_to_question,
                match_to_passage=FLAGS.match_to_passage,
                match_to_choice=FLAGS.match_to_choice,
                with_full_match=(not FLAGS.wo_full_match),
                with_maxpool_match=(not FLAGS.wo_maxpool_match),
                with_attentive_match=(not FLAGS.wo_attentive_match),
                with_max_attentive_match=(not FLAGS.wo_max_attentive_match),
                use_options=FLAGS.use_options,
                num_options=num_options,
                with_no_match=FLAGS.with_no_match,
                matching_option=FLAGS.matching_option,
                concat_context=FLAGS.concat_context,
                tied_aggre=FLAGS.tied_aggre,
                rl_training_method=FLAGS.rl_training_method,
                rl_matches=FLAGS.rl_matches)
        vars_ = {}
        for var in tf.all_variables():
            if "word_embedding" in var.name: continue
            if not var.name.startswith("Model"): continue
            vars_[var.name.split(":")[0]] = var
        saver = tf.train.Saver(vars_)

        sess = tf.Session()
        sess.run(tf.global_variables_initializer())
        step = 0
        saver.restore(sess, best_path)

        accuracy = evaluate(testDataStream,
                            valid_graph,
                            sess,
                            char_vocab=char_vocab,
                            POS_vocab=POS_vocab,
                            NER_vocab=NER_vocab,
                            use_options=FLAGS.use_options)
        print("Accuracy for test set is %.2f" % accuracy)
Ejemplo n.º 28
0
def main(_):
    print('Configurations:')
    print(FLAGS)

    log_dir = FLAGS.model_dir
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    path_prefix = log_dir + "/G2S.{}".format(FLAGS.suffix)
    log_file_path = path_prefix + ".log"
    print('Log file path: {}'.format(log_file_path))
    log_file = open(log_file_path, 'wt')
    log_file.write("{}\n".format(FLAGS))
    log_file.flush()

    # save configuration
    namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")

    word_vocab_enc = None
    word_vocab_dec = None
    char_vocab = None
    edgelabel_vocab = None
    has_pretrained_model = False
    best_path = path_prefix + ".best.model"
    if os.path.exists(best_path + ".index"):
        has_pretrained_model = True
        print('!!Existing pretrained model. Loading vocabs.')
        word_vocab_enc = Vocab(FLAGS.word_vec_src_path, fileformat='txt2')
        print('word_vocab SRC: {}'.format(word_vocab_enc.word_vecs.shape))
        word_vocab_dec = Vocab(FLAGS.word_vec_tgt_path, fileformat='txt2')
        print('word_vocab TGT: {}'.format(word_vocab_dec.word_vecs.shape))
        if FLAGS.with_char:
            char_vocab = Vocab(path_prefix + ".char_vocab", fileformat='txt2')
            print('char_vocab: {}'.format(char_vocab.word_vecs.shape))
        edgelabel_vocab = Vocab(path_prefix + ".edgelabel_vocab", fileformat='txt2')
    else:
        print('Collecting vocabs.')
        word_vocab_enc = Vocab(FLAGS.word_vec_src_path, fileformat='txt2')
        word_vocab_dec = Vocab(FLAGS.word_vec_tgt_path, fileformat='txt2')
        if FLAGS.with_char:
            char_vocab = Vocab(voc=allChars, dim=FLAGS.char_dim, fileformat='build')
            char_vocab.dump_to_txt2(path_prefix + ".char_vocab")
        allEdgelabels = set([line.strip().split()[0] \
                for line in open(FLAGS.edgelabel_vocab, 'rU')])
        edgelabel_vocab = Vocab(voc=allEdgelabels, dim=FLAGS.edgelabel_dim, fileformat='build')
        edgelabel_vocab.dump_to_txt2(path_prefix + ".edgelabel_vocab")

    print('word vocab SRC size {}'.format(word_vocab_enc.vocab_size))
    print('word vocab TGT size {}'.format(word_vocab_dec.vocab_size))
    sys.stdout.flush()

    print('Loading train set.')
    if FLAGS.infile_format == 'fof':
        trainset, trn_node, trn_in_neigh, trn_out_neigh, trn_sent = G2S_data_stream.read_amr_from_fof(FLAGS.train_path, FLAGS,
                word_vocab_enc, word_vocab_dec, char_vocab, edgelabel_vocab)
    else:
        trainset, trn_node, trn_in_neigh, trn_out_neigh, trn_sent = G2S_data_stream.read_amr_file(FLAGS.train_path, FLAGS,
                word_vocab_enc, word_vocab_dec, char_vocab, edgelabel_vocab)
    print('Number of training samples: {}'.format(len(trainset)))

    print('Loading test set.')
    if FLAGS.infile_format == 'fof':
        testset, tst_node, tst_in_neigh, tst_out_neigh, tst_sent = G2S_data_stream.read_amr_from_fof(FLAGS.test_path, FLAGS,
                word_vocab_enc, word_vocab_dec, char_vocab, edgelabel_vocab)
    else:
        testset, tst_node, tst_in_neigh, tst_out_neigh, tst_sent = G2S_data_stream.read_amr_file(FLAGS.test_path, FLAGS,
                word_vocab_enc, word_vocab_dec, char_vocab, edgelabel_vocab)
    print('Number of test samples: {}'.format(len(testset)))

    max_node = max(trn_node, tst_node)
    max_in_neigh = max(trn_in_neigh, tst_in_neigh)
    max_out_neigh = max(trn_out_neigh, tst_out_neigh)
    max_sent = max(trn_sent, tst_sent)
    print('Max node number: {}, while max allowed is {}'.format(max_node, FLAGS.max_node_num))
    print('Max parent number: {}, truncated to {}'.format(max_in_neigh, FLAGS.max_in_neigh_num))
    print('Max children number: {}, truncated to {}'.format(max_out_neigh, FLAGS.max_out_neigh_num))
    print('Max answer length: {}, truncated to {}'.format(max_sent, FLAGS.max_answer_len))

    print('Build DataStream ... ')
    trainDataStream = G2S_data_stream.G2SDataStream(trainset, word_vocab_enc, word_vocab_dec, char_vocab, edgelabel_vocab,
            options=FLAGS, isShuffle=True, isLoop=True, isSort=True)

    devDataStream = G2S_data_stream.G2SDataStream(testset, word_vocab_enc, word_vocab_dec, char_vocab, edgelabel_vocab,
            options=FLAGS, isShuffle=False, isLoop=False, isSort=True)
    print('Number of instances in trainDataStream: {}'.format(trainDataStream.get_num_instance()))
    print('Number of instances in devDataStream: {}'.format(devDataStream.get_num_instance()))
    print('Number of batches in trainDataStream: {}'.format(trainDataStream.get_num_batch()))
    print('Number of batches in devDataStream: {}'.format(devDataStream.get_num_batch()))
    sys.stdout.flush()

    # initialize the best bleu and accu scores for current training session
    best_accu = FLAGS.best_accu if 'best_accu' in FLAGS.__dict__ else 0.0
    best_bleu = FLAGS.best_bleu if 'best_bleu' in FLAGS.__dict__ else 0.0
    if best_accu > 0.0:
        print('With initial dev accuracy {}'.format(best_accu))
    if best_bleu > 0.0:
        print('With initial dev BLEU score {}'.format(best_bleu))

    init_scale = 0.01
    with tf.Graph().as_default():
        initializer = tf.random_uniform_initializer(-init_scale, init_scale)
        with tf.name_scope("Train"):
            with tf.variable_scope("Model", reuse=None, initializer=initializer):
                train_graph = ModelGraph(word_vocab_enc=word_vocab_enc, word_vocab_dec=word_vocab_dec, Edgelabel_vocab=edgelabel_vocab,
                                         char_vocab=char_vocab, options=FLAGS, mode=FLAGS.mode)

        assert FLAGS.mode in ('ce_train', 'rl_train', )
        valid_mode = 'evaluate' if FLAGS.mode == 'ce_train' else 'evaluate_bleu'
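        # Cross-entropy training is validated with accuracy; RL training is validated
        # with BLEU.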

        with tf.name_scope("Valid"):
            with tf.variable_scope("Model", reuse=True, initializer=initializer):
                valid_graph = ModelGraph(word_vocab_enc=word_vocab_enc, word_vocab_dec=word_vocab_dec, Edgelabel_vocab=edgelabel_vocab,
                                         char_vocab=char_vocab, options=FLAGS, mode=valid_mode)

        initializer = tf.global_variables_initializer()

        for var in tf.trainable_variables():
            print(var)

        vars_ = {}
        for var in tf.global_variables():
            if FLAGS.fix_word_vec and "word_embedding" in var.name: continue
            if not var.name.startswith("Model"): continue
            vars_[var.name.split(":")[0]] = var
        saver = tf.train.Saver(vars_)

        sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
        sess.run(initializer)
        if has_pretrained_model:
            print("Restoring model from " + best_path)
            saver.restore(sess, best_path)
            print("DONE!")

            if FLAGS.mode == 'rl_train' and abs(best_bleu) < 0.00001:
                print("Getting BLEU score for the model")
                sys.stdout.flush()
                best_bleu = evaluate(sess, valid_graph, devDataStream, options=FLAGS)['dev_bleu']
                FLAGS.best_bleu = best_bleu
                namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")
                print('BLEU = %.4f' % best_bleu)
                sys.stdout.flush()
                log_file.write('BLEU = %.4f\n' % best_bleu)
            if FLAGS.mode == 'ce_train' and abs(best_accu) < 0.00001:
                print("Getting ACCU score for the model")
                best_accu = evaluate(sess, valid_graph, devDataStream, options=FLAGS)['dev_accu']
                FLAGS.best_accu = best_accu
                namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")
                print('ACCU = %.4f' % best_accu)
                log_file.write('ACCU = %.4f\n' % best_accu)

        print('Start the training loop.')
        train_size = trainDataStream.get_num_batch()
        max_steps = train_size * FLAGS.max_epochs
        total_loss = 0.0
        start_time = time.time()
        for step in xrange(max_steps):
            cur_batch = trainDataStream.nextBatch()
            cur_batch = G2S_data_stream.G2SBatchPadd(cur_batch)
            if FLAGS.mode == 'rl_train':
                loss_value = train_graph.run_rl_training_subsample(sess, cur_batch, FLAGS)
            elif FLAGS.mode == 'ce_train':
                loss_value = train_graph.run_ce_training(sess, cur_batch, FLAGS)
            total_loss += loss_value

            if step % 100 == 0:
                print('{} '.format(step), end="")
                sys.stdout.flush()

            # Save a checkpoint and evaluate the model periodically.
            if (step + 1) % trainDataStream.get_num_batch() == 0 or (step + 1) == max_steps or (step != 0 and step%2000 == 0):
                print()
                duration = time.time() - start_time
                print('Step %d: loss = %.2f (%.3f sec)' % (step, total_loss, duration))
                log_file.write('Step %d: loss = %.2f (%.3f sec)\n' % (step, total_loss, duration))
                log_file.flush()
                sys.stdout.flush()
                total_loss = 0.0

                # Evaluate against the validation set.
                start_time = time.time()
                print('Validation Data Eval:')
                res_dict = evaluate(sess, valid_graph, devDataStream, options=FLAGS, suffix=str(step))
                if valid_graph.mode == 'evaluate':
                    dev_loss = res_dict['dev_loss']
                    dev_accu = res_dict['dev_accu']
                    dev_right = int(res_dict['dev_right'])
                    dev_total = int(res_dict['dev_total'])
                    print('Dev loss = %.4f' % dev_loss)
                    log_file.write('Dev loss = %.4f\n' % dev_loss)
                    print('Dev accu = %.4f %d/%d' % (dev_accu, dev_right, dev_total))
                    log_file.write('Dev accu = %.4f %d/%d\n' % (dev_accu, dev_right, dev_total))
                    log_file.flush()
                    if best_accu < dev_accu:
                        print('Saving weights, ACCU {} (prev_best) < {} (cur)'.format(best_accu, dev_accu))
                        saver.save(sess, best_path)
                        best_accu = dev_accu
                        FLAGS.best_accu = dev_accu
                        namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")
                else:
                    dev_bleu = res_dict['dev_bleu']
                    print('Dev bleu = %.4f' % dev_bleu)
                    log_file.write('Dev bleu = %.4f\n' % dev_bleu)
                    log_file.flush()
                    if best_bleu < dev_bleu:
                        print('Saving weights, BLEU {} (prev_best) < {} (cur)'.format(best_bleu, dev_bleu))
                        saver.save(sess, best_path)
                        best_bleu = dev_bleu
                        FLAGS.best_bleu = dev_bleu
                        namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")
                duration = time.time() - start_time
                print('Duration %.3f sec' % (duration))
                sys.stdout.flush()

                log_file.write('Duration %.3f sec\n' % (duration))
                log_file.flush()

    log_file.close()
Ejemplo n.º 29
0
def get_test_result(in_p, root_path):
    print('Loading configurations.')
    model_prefix = root_path + "/stsapp/src/logs/SentenceMatch.snli"
    word_vec_path = root_path + "/stsapp/src/data/snli/wordvec.txt"
    in_path = in_p
    out_path = root_path + "/stsapp/src/result.txt"

    print("access decoder")
    options = namespace_utils.load_namespace(model_prefix + ".config.json")
    if word_vec_path is None: word_vec_path = options.word_vec_path

    # load vocabs
    print('Loading vocabs.')
    word_vocab = Vocab(word_vec_path, fileformat='txt3')
    label_vocab = Vocab(model_prefix + ".label_vocab", fileformat='txt2')
    print('word_vocab: {}'.format(word_vocab.word_vecs.shape))
    print('label_vocab: {}'.format(label_vocab.word_vecs.shape))
    num_classes = label_vocab.size()

    char_vocab = None
    if options.with_char:
        char_vocab = Vocab(model_prefix + ".char_vocab", fileformat='txt2')
        print('char_vocab: {}'.format(char_vocab.word_vecs.shape))

    print('Build SentenceMatchDataStream ... ')
    testDataStream = SentenceMatchDataStream(in_path, word_vocab=word_vocab, char_vocab=char_vocab,
                                            label_vocab=label_vocab,
                                            isShuffle=False, isLoop=True, isSort=True, options=options)
    print('Number of instances in devDataStream: {}'.format(testDataStream.get_num_instance()))
    print('Number of batches in devDataStream: {}'.format(testDataStream.get_num_batch()))
    sys.stdout.flush()

    best_path = model_prefix + ".best.model"
    init_scale = 0.01
    with tf.Graph().as_default():
        initializer = tf.random_uniform_initializer(-init_scale, init_scale)
        global_step = tf.train.get_or_create_global_step()
        with tf.variable_scope("Model", reuse=False, initializer=initializer):
            valid_graph = SentenceMatchModelGraph(num_classes, word_vocab=word_vocab, char_vocab=char_vocab,
                                                  is_training=False, options=options)

        initializer = tf.global_variables_initializer()
        vars_ = {}
        for var in tf.global_variables():
            if "word_embedding" in var.name: continue
            if not var.name.startswith("Model"): continue
            vars_[var.name.split(":")[0]] = var
        saver = tf.train.Saver(vars_)

        sess = tf.Session()
        sess.run(initializer)
        print("Restoring model from " + best_path)
        saver.restore(sess, best_path)
        print("DONE!")
        acc,result = train.evaluation(sess, valid_graph, testDataStream, outpath=out_path,
                                              label_vocab=label_vocab)

        print("Accuracy for test set is : ",colored(acc, 'green'),"\n")

        # print(result['probs'])

        return acc,result
Ejemplo n.º 30
0
def main(_):
    print('Configurations:')
    print(FLAGS)

    log_dir = FLAGS.model_dir
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    path_prefix = log_dir + "/NP2P.{}".format(FLAGS.suffix)
    log_file_path = path_prefix + ".log"
    print('Log file path: {}'.format(log_file_path))
    log_file = open(log_file_path, 'wt')
    log_file.write("{}\n".format(FLAGS))
    log_file.flush()

    # save configuration
    namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")

    print('Loading train set.')
    if FLAGS.infile_format == 'fof':
        trainset, train_ans_len = NP2P_data_stream.read_generation_datasets_from_fof(
            FLAGS.train_path, isLower=FLAGS.isLower)
    elif FLAGS.infile_format == 'plain':
        trainset, train_ans_len = NP2P_data_stream.read_all_GenerationDatasets(
            FLAGS.train_path, isLower=FLAGS.isLower)
    else:
        trainset, train_ans_len = NP2P_data_stream.read_all_GQA_questions(
            FLAGS.train_path, isLower=FLAGS.isLower, switch=FLAGS.switch_qa)
    print('Number of training samples: {}'.format(len(trainset)))

    print('Loading test set.')
    if FLAGS.infile_format == 'fof':
        testset, test_ans_len = NP2P_data_stream.read_generation_datasets_from_fof(
            FLAGS.test_path, isLower=FLAGS.isLower)
    elif FLAGS.infile_format == 'plain':
        testset, test_ans_len = NP2P_data_stream.read_all_GenerationDatasets(
            FLAGS.test_path, isLower=FLAGS.isLower)
    else:
        testset, test_ans_len = NP2P_data_stream.read_all_GQA_questions(
            FLAGS.test_path, isLower=FLAGS.isLower, switch=FLAGS.switch_qa)
    print('Number of test samples: {}'.format(len(testset)))

    max_actual_len = max(train_ans_len, test_ans_len)
    print('Max answer length: {}, truncated to {}'.format(
        max_actual_len, FLAGS.max_answer_len))

    word_vocab = None
    POS_vocab = None
    NER_vocab = None
    char_vocab = None
    has_pretrained_model = False
    best_path = path_prefix + ".best.model"
    if os.path.exists(best_path + ".index"):
        has_pretrained_model = True
        print('!!Existing pretrained model. Loading vocabs.')
        if FLAGS.with_word:
            word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2')
            print('word_vocab: {}'.format(word_vocab.word_vecs.shape))
        if FLAGS.with_char:
            char_vocab = Vocab(path_prefix + ".char_vocab", fileformat='txt2')
            print('char_vocab: {}'.format(char_vocab.word_vecs.shape))
        if FLAGS.with_POS:
            POS_vocab = Vocab(path_prefix + ".POS_vocab", fileformat='txt2')
            print('POS_vocab: {}'.format(POS_vocab.word_vecs.shape))
        if FLAGS.with_NER:
            NER_vocab = Vocab(path_prefix + ".NER_vocab", fileformat='txt2')
            print('NER_vocab: {}'.format(NER_vocab.word_vecs.shape))
    else:
        print('Collecting vocabs.')
        (allWords, allChars, allPOSs,
         allNERs) = NP2P_data_stream.collect_vocabs(trainset)
        print('Number of words: {}'.format(len(allWords)))
        print('Number of allChars: {}'.format(len(allChars)))
        print('Number of allPOSs: {}'.format(len(allPOSs)))
        print('Number of allNERs: {}'.format(len(allNERs)))

        if FLAGS.with_word:
            word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2')
        if FLAGS.with_char:
            char_vocab = Vocab(voc=allChars,
                               dim=FLAGS.char_dim,
                               fileformat='build')
            char_vocab.dump_to_txt2(path_prefix + ".char_vocab")
        if FLAGS.with_POS:
            POS_vocab = Vocab(voc=allPOSs,
                              dim=FLAGS.POS_dim,
                              fileformat='build')
            POS_vocab.dump_to_txt2(path_prefix + ".POS_vocab")
        if FLAGS.with_NER:
            NER_vocab = Vocab(voc=allNERs,
                              dim=FLAGS.NER_dim,
                              fileformat='build')
            NER_vocab.dump_to_txt2(path_prefix + ".NER_vocab")

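    # Assumes FLAGS.with_word is set; otherwise word_vocab would still be None here.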
    print('word vocab size {}'.format(word_vocab.vocab_size))
    sys.stdout.flush()

    print('Build DataStream ... ')
    trainDataStream = NP2P_data_stream.QADataStream(trainset,
                                                    word_vocab,
                                                    char_vocab,
                                                    POS_vocab,
                                                    NER_vocab,
                                                    options=FLAGS,
                                                    isShuffle=True,
                                                    isLoop=True,
                                                    isSort=True)

    devDataStream = NP2P_data_stream.QADataStream(testset,
                                                  word_vocab,
                                                  char_vocab,
                                                  POS_vocab,
                                                  NER_vocab,
                                                  options=FLAGS,
                                                  isShuffle=False,
                                                  isLoop=False,
                                                  isSort=True)
    print('Number of instances in trainDataStream: {}'.format(
        trainDataStream.get_num_instance()))
    print('Number of instances in devDataStream: {}'.format(
        devDataStream.get_num_instance()))
    print('Number of batches in trainDataStream: {}'.format(
        trainDataStream.get_num_batch()))
    print('Number of batches in devDataStream: {}'.format(
        devDataStream.get_num_batch()))
    sys.stdout.flush()

    init_scale = 0.01
    # initialize the best bleu and accu scores for current training session
    best_accu = FLAGS.best_accu if 'best_accu' in FLAGS.__dict__ else 0.0
    best_bleu = FLAGS.best_bleu if 'best_bleu' in FLAGS.__dict__ else 0.0
    if best_accu > 0.0:
        print('With initial dev accuracy {}'.format(best_accu))
    if best_bleu > 0.0:
        print('With initial dev BLEU score {}'.format(best_bleu))

    with tf.Graph().as_default():
        initializer = tf.random_uniform_initializer(-init_scale, init_scale)
        with tf.name_scope("Train"):
            with tf.variable_scope("Model",
                                   reuse=None,
                                   initializer=initializer):
                train_graph = ModelGraph(word_vocab=word_vocab,
                                         char_vocab=char_vocab,
                                         POS_vocab=POS_vocab,
                                         NER_vocab=NER_vocab,
                                         options=FLAGS,
                                         mode=FLAGS.mode)

        assert FLAGS.mode in (
            'ce_train',
            'rl_train',
        )
        valid_mode = 'evaluate' if FLAGS.mode == 'ce_train' else 'evaluate_bleu'

        with tf.name_scope("Valid"):
            with tf.variable_scope("Model",
                                   reuse=True,
                                   initializer=initializer):
                valid_graph = ModelGraph(word_vocab=word_vocab,
                                         char_vocab=char_vocab,
                                         POS_vocab=POS_vocab,
                                         NER_vocab=NER_vocab,
                                         options=FLAGS,
                                         mode=valid_mode)

        initializer = tf.global_variables_initializer()

        vars_ = {}
        for var in tf.global_variables():
            if "word_embedding" in var.name: continue
            if not var.name.startswith("Model"): continue
            vars_[var.name.split(":")[0]] = var
        saver = tf.train.Saver(vars_)

        sess = tf.Session()
        sess.run(initializer)
        if has_pretrained_model:
            print("Restoring model from " + best_path)
            saver.restore(sess, best_path)
            print("DONE!")

            if FLAGS.mode == 'rl_train' and abs(best_bleu) < 0.00001:
                print("Getting BLEU score for the model")
                best_bleu = evaluate(sess,
                                     valid_graph,
                                     devDataStream,
                                     options=FLAGS)['dev_bleu']
                FLAGS.best_bleu = best_bleu
                namespace_utils.save_namespace(FLAGS,
                                               path_prefix + ".config.json")
                print('BLEU = %.4f' % best_bleu)
                log_file.write('BLEU = %.4f\n' % best_bleu)
            if FLAGS.mode == 'ce_train' and abs(best_accu) < 0.00001:
                print("Getting ACCU score for the model")
                best_accu = evaluate(sess,
                                     valid_graph,
                                     devDataStream,
                                     options=FLAGS)['dev_accu']
                FLAGS.best_accu = best_accu
                namespace_utils.save_namespace(FLAGS,
                                               path_prefix + ".config.json")
                print('ACCU = %.4f' % best_accu)
                log_file.write('ACCU = %.4f\n' % best_accu)

        print('Start the training loop.')
        train_size = trainDataStream.get_num_batch()
        max_steps = train_size * FLAGS.max_epochs
        total_loss = 0.0
        start_time = time.time()
        for step in xrange(max_steps):
            cur_batch = trainDataStream.nextBatch()
            if FLAGS.mode == 'rl_train':
                loss_value = train_graph.run_rl_training_2(
                    sess, cur_batch, FLAGS)
            elif FLAGS.mode == 'ce_train':
                loss_value = train_graph.run_ce_training(
                    sess, cur_batch, FLAGS)
            total_loss += loss_value

            if step % 100 == 0:
                print('{} '.format(step), end="")
                sys.stdout.flush()

            # Save a checkpoint and evaluate the model periodically.
            if (step + 1) % trainDataStream.get_num_batch() == 0 or (
                    step + 1) == max_steps:
                print()
                duration = time.time() - start_time
                print('Step %d: loss = %.2f (%.3f sec)' %
                      (step, total_loss, duration))
                log_file.write('Step %d: loss = %.2f (%.3f sec)\n' %
                               (step, total_loss, duration))
                log_file.flush()
                sys.stdout.flush()
                total_loss = 0.0

                # Evaluate against the validation set.
                start_time = time.time()
                print('Validation Data Eval:')
                res_dict = evaluate(sess,
                                    valid_graph,
                                    devDataStream,
                                    options=FLAGS,
                                    suffix=str(step))
                if valid_graph.mode == 'evaluate':
                    dev_loss = res_dict['dev_loss']
                    dev_accu = res_dict['dev_accu']
                    dev_right = int(res_dict['dev_right'])
                    dev_total = int(res_dict['dev_total'])
                    print('Dev loss = %.4f' % dev_loss)
                    log_file.write('Dev loss = %.4f\n' % dev_loss)
                    print('Dev accu = %.4f %d/%d' %
                          (dev_accu, dev_right, dev_total))
                    log_file.write('Dev accu = %.4f %d/%d\n' %
                                   (dev_accu, dev_right, dev_total))
                    log_file.flush()
                    if best_accu < dev_accu:
                        print('Saving weights, ACCU {} (prev_best) < {} (cur)'.
                              format(best_accu, dev_accu))
                        saver.save(sess, best_path)
                        best_accu = dev_accu
                        FLAGS.best_accu = dev_accu
                        namespace_utils.save_namespace(
                            FLAGS, path_prefix + ".config.json")
                else:
                    dev_bleu = res_dict['dev_bleu']
                    print('Dev bleu = %.4f' % dev_bleu)
                    log_file.write('Dev bleu = %.4f\n' % dev_bleu)
                    log_file.flush()
                    if best_bleu < dev_bleu:
                        print('Saving weights, BLEU {} (prev_best) < {} (cur)'.
                              format(best_bleu, dev_bleu))
                        saver.save(sess, best_path)
                        best_bleu = dev_bleu
                        FLAGS.best_bleu = dev_bleu
                        namespace_utils.save_namespace(
                            FLAGS, path_prefix + ".config.json")
                duration = time.time() - start_time
                print('Duration %.3f sec' % (duration))
                sys.stdout.flush()

                log_file.write('Duration %.3f sec\n' % (duration))
                log_file.flush()

    log_file.close()
Ejemplo n.º 31
0
def main_func(_):
    print(FLAGS)
    save_path = FLAGS.train_dir + "tfFile/"
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    print("1. Loading WordVocab data...")
    wordVocab = Vocab()
    wordVocab.fromText_format3(FLAGS.train_dir, FLAGS.wordvec_path)
    sys.stdout.flush()

    prepare = Prepare()
    if FLAGS.hasTfrecords:
        print("2. Has Tfrecords File---Train---")
        total_lines = prepare.processTFrecords_hasDone(savePath=save_path, taskNumber=FLAGS.taskNumber)
    else:
        print("2. Start generating TFrecords File--train...")
        total_lines = prepare.processTFrecords(wordVocab, savePath=save_path, max_len=FLAGS.max_len,
                                               taskNumber=FLAGS.taskNumber)
    print("totalLines_train_0:", total_lines[0])
    print("totalLines_train_1:", total_lines[1])
    sys.stdout.flush()

    test_path = FLAGS.train_dir + FLAGS.test_path
    if FLAGS.hasTfrecords:
        print("3. Has TFrecords File--test...")
        totalLines_test = prepare.processTFrecords_test_hasDone(test_path=test_path, taskNumber=1)
    else:
        print("3. Start generating TFrecords File--test...")
        totalLines_test = prepare.processTFrecords_test(wordVocab,
                                                        savePath=save_path,
                                                        test_path=test_path,
                                                        max_len=FLAGS.max_len,
                                                        taskNumber=1)
    print("totalLines_test:", totalLines_test)
    sys.stdout.flush()

    print("4. Start loading TFrecords File...")
    taskNameList = []
    for i in range(FLAGS.taskNumber):
        string = FLAGS.train_dir + 'tfFile/train-' + str(i) + '.tfrecords'
        taskNameList.append(string)
    print("taskNameList: ", taskNameList)
    sys.stdout.flush()

    ################
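    # Task 0 has roughly n times as many training lines as task 1, so task 1 is read
    # with a batch size scaled down by n; the two input pipelines then stay roughly in
    # step over an epoch.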
    n = total_lines[0] // total_lines[1] + 1 if \
        total_lines[0] % total_lines[1] != 0 else \
        total_lines[0] // total_lines[1]
    print("n: ", n)
    num_batches_per_epoch_train_0 = int(total_lines[0] / FLAGS.batch_size) + 1 if \
        total_lines[0] % FLAGS.batch_size != 0 else \
        int(total_lines[0] / FLAGS.batch_size)
    print("batch_numbers_train_0:", num_batches_per_epoch_train_0)
    batch_size_1 = FLAGS.batch_size // n

    num_batches_per_epoch_test = int(totalLines_test / FLAGS.batch_size) + 1 if \
        totalLines_test % FLAGS.batch_size != 0 else \
        int(totalLines_test / FLAGS.batch_size)
    print("batch_numbers_test:", num_batches_per_epoch_test)

    with tf.Graph().as_default():
        all_test = prepare.read_records(
            taskname=save_path + "test-0.tfrecords",
            max_len=FLAGS.max_len,
            epochs=FLAGS.num_epochs,
            batch_size=FLAGS.batch_size)

        all_train_0 = prepare.read_records(
            taskname=taskNameList[0],
            max_len=FLAGS.max_len,
            epochs=FLAGS.num_epochs,
            batch_size=FLAGS.batch_size)

        all_train_1 = prepare.read_records(
            taskname=taskNameList[1],
            max_len=FLAGS.max_len,
            epochs=FLAGS.num_epochs,
            batch_size=batch_size_1)

        print("Loading Model...")
        sys.stdout.flush()

        session_conf = tf.ConfigProto(
            allow_soft_placement=FLAGS.allow_soft_placement,
            log_device_placement=FLAGS.log_device_placement)
        session_conf.gpu_options.allow_growth = True
        sess = tf.Session(config=session_conf)
        with sess.as_default():
            print("------------train model--------------")
            m_train = mtl_model.MTLModel(max_len=FLAGS.max_len,
                                         filter_sizes=list(map(int, FLAGS.filter_sizes.split(","))),
                                         num_filters=FLAGS.num_filters,
                                         num_hidden=FLAGS.num_hidden,
                                         word_vocab=wordVocab,
                                         l2_reg_lambda=FLAGS.l2_reg_lambda,
                                         learning_rate=FLAGS.learning_rate,
                                         adv=FLAGS.adv)
            m_train.build_train_op()
            print("\n\n")

            saver = tf.train.Saver(tf.global_variables(), max_to_keep=20)
            init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
            sess.run(init_op)
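            # The TFRecord inputs above are queue-based, so a Coordinator and the queue
            # runners must be started before any sess.run() call that pulls a batch.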
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            has_pre_trained_model = False
            out_dir = os.path.abspath(os.path.join(FLAGS.train_dir, "runs"))

            print(out_dir)
            if not os.path.exists(out_dir):
                os.makedirs(out_dir)
            else:
                print("continue training models")
                ckpt = tf.train.get_checkpoint_state(out_dir)
                if ckpt and ckpt.model_checkpoint_path:
                    print("-------has_pre_trained_model--------")
                    print(ckpt.model_checkpoint_path)
                    has_pre_trained_model = True
                    sys.stdout.flush()

            checkpoint_prefix = os.path.join(out_dir, "model")
            if has_pre_trained_model:
                print("Restoring model from " + ckpt.model_checkpoint_path)
                saver.restore(sess, ckpt.model_checkpoint_path)
                print("DONE!")
                sys.stdout.flush()

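            # Evaluate task 1 on the held-out TFRecords. The task-0 placeholders are fed
            # with the most recent training batch (captured from the enclosing loop via
            # closure) only to satisfy the graph's feed requirements.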
            def dev_whole(num_batches_per_epoch_test):
                accuracies = []
                losses = []

                for j in range(num_batches_per_epoch_test):
                    input_y_test, input_left_test, input_centre_test = sess.run(
                        [all_test[0], all_test[1], all_test[2]])
                    loss, accuracy, loss_adv, loss_ce = sess.run(
                        [m_train.tensors[1][1], m_train.tensors[1][0], m_train.tensors[1][2],
                         m_train.tensors[1][3]],
                        feed_dict={
                            m_train.input_task_0: 0,
                            m_train.input_left_0: input_left_real_0,
                            m_train.input_right_0: input_centre_real_0,
                            m_train.input_y_0: input_y_real_0,
                            m_train.dropout_keep_prob: 1.0,  # keep all units at evaluation time
                            m_train.input_task_1: 1,
                            m_train.input_left_1: input_left_test,
                            m_train.input_right_1: input_centre_test,
                            m_train.input_y_1: input_y_test,
                        })
                    losses.append(loss_ce)
                    accuracies.append(accuracy)
                # print("specfic_prob: ", prob_test)
                sys.stdout.flush()
                return np.mean(np.array(losses)), np.mean(np.array(accuracies))

            def overfit(dev_accuracy):
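                # Flag over-fitting when dev accuracy has not improved for four
                # consecutive evaluations; e.g. (hypothetical numbers)
                # overfit([0.61, 0.64, 0.66, 0.65, 0.64, 0.63, 0.62]) returns True.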
                n = len(dev_accuracy)
                if n < 4:
                    return False
                for i in range(n - 4, n):
                    if dev_accuracy[i] > dev_accuracy[i - 1]:
                        return False
                return True

            dev_accuracy = []
            total_train_loss = []

            train_loss_0 = 0
            train_loss_1 = 0
            loss_task_0 = 0
            loss_task_1 = 0
            adv_0 = 0
            adv_1 = 0
            acc_1 = 0
            count = 0
            try:
                while not coord.should_stop():  ## for each epoch
                    for i in range(num_batches_per_epoch_train_0 * FLAGS.num_epochs):  ## for each batch
                        input_y_real_0, input_left_real_0, input_centre_real_0 = sess.run([all_train_0[0],
                                                                                           all_train_0[1],
                                                                                           all_train_0[2]])
                        input_y_real_1, input_left_real_1, input_centre_real_1 = sess.run([all_train_1[0],
                                                                                           all_train_1[1],
                                                                                           all_train_1[2]])

                        # acc, loss, loss_adv = m_train.tensors[0]
                        # _, current_step_0, loss_0, accuracy_0, loss_adv_0 = sess.run(
                        #     [m_train.train_ops[0][0], m_train.train_ops[0][1],
                        #      m_train.tensors[0][1], m_train.tensors[0][0], m_train.tensors[0][2]],
                        #     feed_dict={
                        #         m_train.input_task_0: 0,
                        #         m_train.input_left_0: input_left_real_0,
                        #         m_train.input_right_0: input_centre_real_0,
                        #         m_train.input_y_0: input_y_real_0,
                        #         m_train.dropout_keep_prob: FLAGS.dropout_keep_prob,
                        #         m_train.input_task_1: 1,
                        #         m_train.input_left_1: input_left_real_1,
                        #         m_train.input_right_1: input_centre_real_1,
                        #         m_train.input_y_1: input_y_real_1,
                        #     })
                        # all_loss_adv += loss_adv_0
                        # train_acc += accuracy_0
                        # train_loss_0 += loss_0
                        # train_loss += loss_0
                        #
                        # _, current_step_1, loss_1, accuracy_1, loss_adv_1 = sess.run(
                        #     [m_train.train_ops[1][0], m_train.train_ops[1][1],
                        #      m_train.tensors[1][1], m_train.tensors[1][0], m_train.tensors[1][2]],
                        #     feed_dict={
                        #         m_train.input_task_0: 0,
                        #         m_train.input_left_0: input_left_real_0,
                        #         m_train.input_right_0: input_centre_real_0,
                        #         m_train.input_y_0: input_y_real_0,
                        #         m_train.dropout_keep_prob: FLAGS.dropout_keep_prob,
                        #         m_train.input_task_1: 1,
                        #         m_train.input_left_1: input_left_real_1,
                        #         m_train.input_right_1: input_centre_real_1,
                        #         m_train.input_y_1: input_y_real_1,
                        #     })
                        _, loss_0, accuracy_0, loss_adv_0, loss_ce_0 = sess.run(
                            [m_train.train_ops[0],
                             m_train.tensors[0][1], m_train.tensors[0][0], m_train.tensors[0][2],
                             m_train.tensors[0][3]],
                            feed_dict={
                                m_train.input_task_0: 0,
                                m_train.input_left_0: input_left_real_0,
                                m_train.input_right_0: input_centre_real_0,
                                m_train.input_y_0: input_y_real_0,
                                m_train.dropout_keep_prob: FLAGS.dropout_keep_prob,
                                m_train.input_task_1: 1,
                                m_train.input_left_1: input_left_real_1,
                                m_train.input_right_1: input_centre_real_1,
                                m_train.input_y_1: input_y_real_1,
                            })
                        train_loss_0 += loss_0
                        loss_task_0 += loss_ce_0
                        adv_0 += loss_adv_0

                        _, loss_1, accuracy_1, loss_adv_1, loss_ce_1 = sess.run(
                            [m_train.train_ops[1],
                             m_train.tensors[1][1], m_train.tensors[1][0], m_train.tensors[1][2],
                             m_train.tensors[1][3]],
                            feed_dict={
                                m_train.input_task_0: 0,
                                m_train.input_left_0: input_left_real_0,
                                m_train.input_right_0: input_centre_real_0,
                                m_train.input_y_0: input_y_real_0,
                                m_train.dropout_keep_prob: FLAGS.dropout_keep_prob,
                                m_train.input_task_1: 1,
                                m_train.input_left_1: input_left_real_1,
                                m_train.input_right_1: input_centre_real_1,
                                m_train.input_y_1: input_y_real_1,
                            })
                        train_loss_1 += loss_1
                        loss_task_1 += loss_ce_1
                        adv_1 += loss_adv_1
                        acc_1 += accuracy_1

                        count += 1
                        if count % 500 == 0:
                            print("loss {}, acc {}".format(loss_0, accuracy_0))
                            print("--loss {}, acc {}, loss_adv {}, loss_ce {}--".format(loss_1, accuracy_1, loss_adv_1,
                                                                                        loss_ce_1))
                            sys.stdout.flush()

                        if count % num_batches_per_epoch_train_0 == 0 or \
                                count == num_batches_per_epoch_train_0 * FLAGS.num_epochs:
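                            # End of a task-0 epoch (or the final step): report the epoch
                            # totals, run validation, and checkpoint/freeze the model.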

                            print("train_0: ", count / num_batches_per_epoch_train_0,
                                  " epoch, train_loss_0:", train_loss_0,
                                  "loss_task_0: ", loss_task_0,
                                  "adv_0: ", adv_0)

                            print(
                                "train_1: ", count / num_batches_per_epoch_train_0,
                                " epoch, train_loss_1: ", train_loss_1,
                                "loss_task_1: ", loss_task_1,
                                "adv_1: ", adv_1,
                                "acc_1 : ", acc_1 / num_batches_per_epoch_train_0)

                            total_train_loss.append(loss_task_1)
                            train_loss_0 = 0
                            train_loss_1 = 0
                            loss_task_0 = 0
                            loss_task_1 = 0
                            adv_0 = 0
                            adv_1 = 0
                            acc_1 = 0
                            sys.stdout.flush()

                            print("\n------------------Evaluation:-----------------------")
                            _, accuracy = dev_whole(num_batches_per_epoch_test)
                            dev_accuracy.append(accuracy)
                            print("--------Recently dev accuracy:--------")
                            print(dev_accuracy[-10:])

                            print("--------Recently loss_task_1:------")
                            print(total_train_loss[-10:])
                            if overfit(dev_accuracy):
                                print('-----Overfit!!----')
                                break
                            print("")
                            sys.stdout.flush()

                            # continue
                            path = saver.save(sess, checkpoint_prefix, global_step=count)

                            print("-------------------Saved model checkpoint to {}--------------------".format(path))
                            sys.stdout.flush()
                            output_graph_def = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def,
                                                                                            output_node_names=[
                                                                                                'task_1/prob'])
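                            # Common workaround when freezing a graph that contains
                            # moving-average / batch-norm update ops: rewrite RefSwitch and
                            # AssignSub nodes so the frozen graph only reads those variables
                            # instead of assigning to them.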
                            for node in output_graph_def.node:
                                if node.op == 'RefSwitch':
                                    node.op = 'Switch'
                                    for index in xrange(len(node.input)):
                                        if 'moving_' in node.input[index]:
                                            node.input[index] = node.input[index] + '/read'
                                elif node.op == 'AssignSub':
                                    node.op = 'Sub'
                                    if 'use_locking' in node.attr:
                                        del node.attr['use_locking']

                            with tf.gfile.GFile(FLAGS.train_dir + "runs/mtlmodel_specfic.pb", "wb") as f:
                                f.write(output_graph_def.SerializeToString())
                            print("%d ops in the final graph.\n" % len(output_graph_def.node))




            except tf.errors.OutOfRangeError:
                print("Done")
            finally:
                print("--------------------------finally---------------------------")
                print("current_step:", count)
                coord.request_stop()
                coord.join(threads)

            sess.close()