# Interactive inference: restore the latest checkpoint and classify raw input lines.
# `InputHelper`, `BiRNN`, and `FLAGS` are defined elsewhere in this project.
import tensorflow as tf


def main():
    data_loader = InputHelper()
    data_loader.create_dictionary(FLAGS.data_dir + '/' + FLAGS.train_file, FLAGS.data_dir + '/')
    FLAGS.vocab_size = data_loader.vocab_size
    FLAGS.n_classes = data_loader.n_classes

    model = BiRNN(FLAGS.rnn_size, FLAGS.layer_size, FLAGS.vocab_size, FLAGS.batch_size,
                  FLAGS.sequence_length, FLAGS.n_classes, FLAGS.grad_clip)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver(tf.global_variables())
        # Restore the latest checkpoint if one exists.
        ckpt = tf.train.get_checkpoint_state(FLAGS.save_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)

        while True:
            x = input('Please enter an address:\n')
            x = [data_loader.transform_raw(x, FLAGS.sequence_length)]
            labels = model.inference(sess, data_loader.labels, x)
            print(labels)
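# A minimal sketch of what `transform_raw` is assumed to do above: map each
# character of a raw string to its vocabulary id, then pad or truncate to a
# fixed length. The real InputHelper may differ (the `word2idx` dict and the
# padding id 0 are assumptions, and other versions of this helper also return
# the token list alongside the ids).
def transform_raw_sketch(text, sequence_length, word2idx, pad_id=0):
    ids = [word2idx.get(ch, pad_id) for ch in text[:sequence_length]]
    ids += [pad_id] * (sequence_length - len(ids))  # right-pad to sequence_length
    return ids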
# Validation: restore a fixed checkpoint, score the validation set, and write
# the predictions to a result file. `log` is the project's logger.
import tensorflow as tf


def main():
    data_loader = InputHelper(log=log)
    data_loader.load_embedding(FLAGS.embedding_file, FLAGS.embedding_size)
    data_loader.load_label_dictionary(FLAGS.label_dic)
    x, y, x_w_p, x_s_p = data_loader.load_valid(FLAGS.valid_file, FLAGS.interaction_rounds,
                                                FLAGS.sequence_length)
    FLAGS.embeddings = data_loader.embeddings
    FLAGS.vocab_size = len(data_loader.word2idx)
    FLAGS.n_classes = len(data_loader.label_dictionary)

    model = BiRNN(embedding_size=FLAGS.embedding_size, rnn_size=FLAGS.rnn_size,
                  layer_size=FLAGS.layer_size, vocab_size=FLAGS.vocab_size,
                  attn_size=FLAGS.attn_size, sequence_length=FLAGS.sequence_length,
                  n_classes=FLAGS.n_classes, interaction_rounds=FLAGS.interaction_rounds,
                  batch_size=FLAGS.batch_size, embeddings=FLAGS.embeddings,
                  grad_clip=FLAGS.grad_clip, learning_rate=FLAGS.learning_rate)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver(tf.global_variables())
        ckpt = tf.train.get_checkpoint_state(FLAGS.save_dir)
        # Restore a specific checkpoint rather than the latest one.
        model_path = FLAGS.save_dir + '/model.ckpt-45'
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, model_path)

        labels = model.inference(sess, y, x, x_w_p, x_s_p)
        correct_num = 0
        for i in range(len(labels)):
            if labels[i] == y[i]:
                correct_num += 1
        print('eval_acc = {:.3f}'.format(correct_num * 1.0 / len(labels)))
        data_loader.output_result(labels, FLAGS.valid_file, FLAGS.result_file)
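# The accuracy loop above is a plain elementwise comparison; a minimal numpy
# equivalent (assuming `predicted` and `gold` are equal-length sequences of
# class ids) could look like this:
import numpy as np


def accuracy(predicted, gold):
    # Fraction of positions where the predicted class id matches the gold id.
    return float(np.mean(np.asarray(predicted) == np.asarray(gold)))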
# Feature extraction: run each document's text sequences through the model and
# save one (len(wl), 50) feature matrix per document. `load_wl` returns the
# word-to-row-index whitelist used below.
import re

import numpy as np
import tensorflow as tf


def main():
    data_loader = InputHelper()
    data_loader.create_dictionary(FLAGS.data_dir + '/' + FLAGS.train_file, FLAGS.data_dir + '/')
    FLAGS.vocab_size = data_loader.vocab_size
    FLAGS.n_classes = data_loader.n_classes
    wl = load_wl()

    # Define the specified model.
    model = BiRNN(embedding_size=FLAGS.embedding_size, rnn_size=FLAGS.rnn_size,
                  layer_size=FLAGS.layer_size, vocab_size=FLAGS.vocab_size,
                  attn_size=FLAGS.attn_size, sequence_length=FLAGS.sequence_length,
                  n_classes=FLAGS.n_classes, grad_clip=FLAGS.grad_clip,
                  learning_rate=FLAGS.learning_rate)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver(tf.global_variables())
        ckpt = tf.train.get_checkpoint_state(FLAGS.save_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)

        now = 0
        for file in open('./data/total_txt_img_cat.list'):
            word_list = {}
            arr = np.zeros([len(wl), 50])
            lab = file.split('\t')[2]
            # Accumulate per-word output features over every sequence in the document.
            for line in open('./data/text_seqs/' + file.split()[0] + '.xml'):
                seq = line.split('\t')[0]
                x, w = data_loader.transform_raw(seq, FLAGS.sequence_length)
                _, out_features = model.inference(sess, data_loader.labels, [x])
                for i, j in enumerate(w):
                    punc = '[,.!\'%*+-/=><]'
                    j = re.sub(punc, '', j)  # strip punctuation from the token
                    if j in word_list:
                        word_list[j] += out_features[i]
                    else:
                        word_list[j] = out_features[i]
            # Copy the accumulated features of whitelisted words into their fixed rows.
            count = 0
            for w in word_list:
                if w in wl:
                    arr[wl[w]] = word_list[w]
                    count += 1
            print('now:', now, 'count:', count, 'shape:', arr.shape)
            np.save('./text_lstm/text_{:04d}_{}.npy'.format(now, lab.strip()), arr)
            now += 1
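# Sketch: reading one of the saved matrices back. The file name follows the
# np.save pattern above; index 0000 and label '5' are made-up examples.
import numpy as np

feats = np.load('./text_lstm/text_0000_5.npy')
print(feats.shape)  # (len(wl), 50): one 50-dim feature row per whitelist word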