def main():
    """Interactively classify raw text lines read from stdin.

    Builds the vocabulary from ``FLAGS.data_dir/FLAGS.train_file`` with
    ``InputHelper.create_dictionary``, copies the resulting vocabulary size
    and class count onto ``FLAGS``, builds the ``BiRNN`` model and restores
    the latest checkpoint from ``FLAGS.save_dir`` if one exists.  It then
    repeatedly prompts for a line, transforms it with
    ``data_loader.transform_raw`` and prints the output of
    ``model.inference``.  Blank lines are ignored; end of input (EOF) ends
    the loop.
    """
    import sys

    data_loader = InputHelper()
    data_loader.create_dictionary(FLAGS.data_dir + '/' + FLAGS.train_file,
                                  FLAGS.data_dir + '/')
    FLAGS.vocab_size = data_loader.vocab_size
    FLAGS.n_classes = data_loader.n_classes

    model = BiRNN(FLAGS.rnn_size, FLAGS.layer_size, FLAGS.vocab_size,
                  FLAGS.batch_size, FLAGS.sequence_length, FLAGS.n_classes,
                  FLAGS.grad_clip)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver(tf.global_variables())
        ckpt = tf.train.get_checkpoint_state(FLAGS.save_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)

        while True:
            # Prompt text means "Please enter an address:".  stdin is read
            # directly so this works the same on Python 2 and 3 (the original
            # debug code replaced the prompt with a constant '' and looped
            # forever).
            sys.stdout.write('请输入一个地址:\n')
            sys.stdout.flush()
            line = sys.stdin.readline()
            if not line:  # EOF: end of piped input or Ctrl-D
                break
            raw = line.strip()
            if not raw:
                continue
            x = [data_loader.transform_raw(raw, FLAGS.sequence_length)]

            labels = model.inference(sess, data_loader.labels, x)
            print(labels)
def main():
    """Evaluate a trained BiRNN on the validation file and write predictions.

    Loads embeddings, the label dictionary and the validation data through
    ``InputHelper``, records the embeddings, vocabulary size and class count
    on ``FLAGS``, and builds the ``BiRNN`` model.  When a checkpoint exists
    in ``FLAGS.save_dir`` it restores one (see the comment below), runs
    ``model.inference`` over the validation inputs, prints the accuracy
    against ``y`` and hands the predictions to
    ``data_loader.output_result`` together with ``FLAGS.valid_file`` and
    ``FLAGS.result_file``.
    """
    # NOTE(review): ``log`` is not defined in this function; presumably a
    # module-level logger defined elsewhere in the file.
    data_loader = InputHelper(log=log)

    data_loader.load_embedding(FLAGS.embedding_file, FLAGS.embedding_size)

    data_loader.load_label_dictionary(FLAGS.label_dic)

    x, y, x_w_p, x_s_p = data_loader.load_valid(FLAGS.valid_file,
                                                FLAGS.interaction_rounds,
                                                FLAGS.sequence_length)

    FLAGS.embeddings = data_loader.embeddings
    FLAGS.vocab_size = len(data_loader.word2idx)
    FLAGS.n_classes = len(data_loader.label_dictionary)

    model = BiRNN(embedding_size=FLAGS.embedding_size,
                  rnn_size=FLAGS.rnn_size,
                  layer_size=FLAGS.layer_size,
                  vocab_size=FLAGS.vocab_size,
                  attn_size=FLAGS.attn_size,
                  sequence_length=FLAGS.sequence_length,
                  n_classes=FLAGS.n_classes,
                  interaction_rounds=FLAGS.interaction_rounds,
                  batch_size=FLAGS.batch_size,
                  embeddings=FLAGS.embeddings,
                  grad_clip=FLAGS.grad_clip,
                  learning_rate=FLAGS.learning_rate)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver(tf.global_variables())
        ckpt = tf.train.get_checkpoint_state(FLAGS.save_dir)

        if ckpt and ckpt.model_checkpoint_path:
            # Evaluate the pinned step-45 checkpoint when it is still on disk.
            # tf.train.Saver only keeps the most recent checkpoints by
            # default, so it may have been pruned.  In that case fall back to
            # the latest checkpoint recorded in the directory instead of
            # failing inside restore().
            model_path = FLAGS.save_dir + '/model.ckpt-45'
            if not tf.train.checkpoint_exists(model_path):
                model_path = ckpt.model_checkpoint_path
            saver.restore(sess, model_path)

        labels = model.inference(sess, y, x, x_w_p, x_s_p)

        # Pairwise comparison of predictions with gold labels; zip stops at
        # the shorter sequence instead of raising IndexError.
        correct_num = sum(1 for pred, gold in zip(labels, y) if pred == gold)
        if len(labels):
            print('eval_acc = {:.3f}'.format(correct_num * 1.0 / len(labels)))
        else:
            # Avoid ZeroDivisionError on an empty validation set.
            print('eval_acc = n/a (no predictions)')

        data_loader.output_result(labels, FLAGS.valid_file, FLAGS.result_file)
# Example #3
# 0
def main():
    """Extract per-word BiRNN output features for each document and save them.

    Reads ``./data/total_txt_img_cat.list``.  For each non-blank entry, the
    first whitespace-separated token names ``./data/text_seqs/<id>.xml`` and
    the third tab-separated field is used as the label.  The first
    tab-separated field of every line in that file is run through
    ``model.inference``.  The per-token output features are summed per
    punctuation-stripped word.  Words found in the word list ``wl`` (from
    ``load_wl``) are written to row ``wl[word]`` of a ``(len(wl), 50)``
    array, which is saved as
    ``./text_lstm/text_<4-digit index>_<label>.npy``.
    """
    import os

    data_loader = InputHelper()
    data_loader.create_dictionary(FLAGS.data_dir + '/' + FLAGS.train_file,
                                  FLAGS.data_dir + '/')
    FLAGS.vocab_size = data_loader.vocab_size
    FLAGS.n_classes = data_loader.n_classes
    wl = load_wl()
    # Define specified Model
    model = BiRNN(embedding_size=FLAGS.embedding_size,
                  rnn_size=FLAGS.rnn_size,
                  layer_size=FLAGS.layer_size,
                  vocab_size=FLAGS.vocab_size,
                  attn_size=FLAGS.attn_size,
                  sequence_length=FLAGS.sequence_length,
                  n_classes=FLAGS.n_classes,
                  grad_clip=FLAGS.grad_clip,
                  learning_rate=FLAGS.learning_rate)

    # Punctuation stripped from each token before the word-list lookup.
    # Compiled once instead of on every token.  Inside the class, '+-/' is a
    # range that covers '+ , - . /'.
    punc_re = re.compile(r"[,.!'%*+-/=><]")

    out_dir = './text_lstm'
    if not os.path.isdir(out_dir):
        os.makedirs(out_dir)  # np.save does not create missing directories

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver(tf.global_variables())
        ckpt = tf.train.get_checkpoint_state(FLAGS.save_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)

        now = 0  # index of the next output file; only counts processed entries
        with open('./data/total_txt_img_cat.list') as list_file:
            for entry in list_file:
                if not entry.strip():
                    continue  # a blank line has no id/label fields
                word_feats = {}
                # NOTE(review): 50 is assumed to equal the model's per-token
                # feature size; confirm against BiRNN.
                arr = np.zeros([len(wl), 50])
                lab = entry.split('\t')[2]
                seq_path = './data/text_seqs/' + entry.split()[0] + '.xml'
                with open(seq_path) as seq_file:
                    for line in seq_file:
                        seq = line.split('\t')[0]
                        x, tokens = data_loader.transform_raw(
                            seq, FLAGS.sequence_length)
                        _, out_features = model.inference(
                            sess, data_loader.labels, [x])
                        for i, token in enumerate(tokens):
                            word = punc_re.sub('', token)
                            if word in word_feats:
                                word_feats[word] += out_features[i]
                            else:
                                # Store a copy so the in-place += above never
                                # writes into the model's output array.
                                word_feats[word] = np.array(out_features[i])
                count = 0
                for word, feat in word_feats.items():
                    if word in wl:
                        arr[wl[word]] = feat
                        count += 1
                print('now:', now, 'count:', count, 'shape:', arr.shape)
                np.save(out_dir + '/text_' + str(now).zfill(4) + '_' +
                        lab.strip() + '.npy', arr)
                now += 1