def main():
    """Restore a trained BiRNN classifier and interactively label addresses.

    Builds the vocabulary from the training file, constructs the model,
    restores the latest checkpoint from FLAGS.save_dir, then loops reading
    one address per line from stdin and printing the predicted label(s).
    Runs until interrupted (Ctrl-C / EOF).
    """
    data_loader = InputHelper()
    data_loader.create_dictionary(FLAGS.data_dir + '/' + FLAGS.train_file,
                                  FLAGS.data_dir + '/')
    # Vocabulary/label sizes are only known after the dictionary is built.
    FLAGS.vocab_size = data_loader.vocab_size
    FLAGS.n_classes = data_loader.n_classes
    model = BiRNN(FLAGS.rnn_size, FLAGS.layer_size, FLAGS.vocab_size,
                  FLAGS.batch_size, FLAGS.sequence_length,
                  FLAGS.n_classes, FLAGS.grad_clip)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver(tf.global_variables())
        ckpt = tf.train.get_checkpoint_state(FLAGS.save_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
        while True:
            # BUG FIX: the original commented out the prompt and left x = ''
            # hard-coded, so the loop classified the empty string forever.
            # Read the address from stdin instead (prompt text unchanged).
            x = input('请输入一个地址:\n')
            x = [data_loader.transform_raw(x, FLAGS.sequence_length)]
            labels = model.inference(sess, data_loader.labels, x)
            print(labels)
def main():
    """Dump per-word BiRNN output features for every listed document.

    For each document in ./data/total_txt_img_cat.list, runs each text
    sequence of its XML file through the model, accumulates the per-word
    output features, projects them onto the global word list ``wl``, and
    saves a ``[len(wl), 50]`` matrix to ``./text_lstm/text_<idx>_<label>.npy``.
    """
    data_loader = InputHelper()
    data_loader.create_dictionary(FLAGS.data_dir + '/' + FLAGS.train_file,
                                  FLAGS.data_dir + '/')
    # Vocabulary/label sizes are only known after the dictionary is built.
    FLAGS.vocab_size = data_loader.vocab_size
    FLAGS.n_classes = data_loader.n_classes
    wl = load_wl()
    # Define specified Model
    model = BiRNN(embedding_size=FLAGS.embedding_size,
                  rnn_size=FLAGS.rnn_size,
                  layer_size=FLAGS.layer_size,
                  vocab_size=FLAGS.vocab_size,
                  attn_size=FLAGS.attn_size,
                  sequence_length=FLAGS.sequence_length,
                  n_classes=FLAGS.n_classes,
                  grad_clip=FLAGS.grad_clip,
                  learning_rate=FLAGS.learning_rate)
    # Hoisted out of the innermost loop and compiled once (pattern unchanged).
    # NOTE(review): '+-/' inside the class is a character RANGE (+ , - . /);
    # if a literal '-' was intended it should be escaped — kept as-is.
    punc = re.compile('[,.!\'%*+-/=><]')
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver(tf.global_variables())
        ckpt = tf.train.get_checkpoint_state(FLAGS.save_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
        now = 0
        for file in open('./data/total_txt_img_cat.list'):
            word_list = {}  # word -> accumulated feature vector
            arr = np.zeros([len(wl), 50])
            lab = file.split('\t')[2]
            for line in open('./data/text_seqs/' + file.split()[0] + '.xml'):
                seq = line.split('\t')[0]
                x, w = data_loader.transform_raw(seq, FLAGS.sequence_length)
                _, out_features = model.inference(sess, data_loader.labels, [x])
                for i, j in enumerate(w):
                    j = punc.sub('', j)
                    # Sum features for repeated occurrences of the same word.
                    if j in word_list:
                        word_list[j] += out_features[i]
                    else:
                        word_list[j] = out_features[i]
            count = 0
            # Renamed loop variable (was `w`, shadowing the words list above).
            for word in word_list:
                if word in wl:
                    arr[wl[word]] = word_list[word]
                    count += 1
            print('now:', now, 'count:', count, 'shape:', arr.shape)
            # zfill replaces the original manual zero-padding while-loop.
            np.save('./text_lstm/text_' + str(now).zfill(4) + '_'
                    + lab.strip() + '.npy', arr)
            now += 1
def main():
    """Run attention-BiRNN inference over ./data/test.txt and print weights.

    Loads a prebuilt dictionary, restores the latest checkpoint, then for
    each tab-separated test line predicts a label and prints each word with
    its attention weight (normalized by the max weight over the real words).
    Lines that fail to transform or infer are echoed back unprocessed.
    """
    data_loader = InputHelper('data/stop_words.pkl')
    data_loader.load_dictionary(FLAGS.data_dir + '/dictionary')
    # Vocabulary/label sizes come from the loaded dictionary.
    FLAGS.vocab_size = data_loader.vocab_size
    FLAGS.n_classes = data_loader.n_classes
    # Define specified Model
    model = AttentionBiRNN(FLAGS.embedding_size, FLAGS.rnn_size,
                           FLAGS.layer_size, FLAGS.vocab_size,
                           FLAGS.attn_size, FLAGS.sequence_length,
                           FLAGS.n_classes, FLAGS.grad_clip,
                           FLAGS.learning_rate)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver(tf.global_variables())
        ckpt = tf.train.get_checkpoint_state(FLAGS.save_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
        with open('./data/test.txt', 'r') as f:
            lines = f.readlines()
        for line in lines:
            text = line.split('\t')
            l = text[0]
            try:
                x = [data_loader.transform_raw(l, FLAGS.sequence_length)]
                words = data_loader.words
                labels, alphas, _ = model.inference(sess, data_loader.labels, x)
                print(labels, text[1].replace('\n', ''), len(words))
                words_weights = []
                # Normalize by the largest weight among the actual words
                # (alphas is padded to sequence_length).
                for word, alpha in zip(words,
                                       alphas[0] / alphas[0][0:len(words)].max()):
                    words_weights.append(word + ':' + str(alpha))
                # NOTE(review): str(...).decode(...) is Python-2-only; under
                # Python 3 this raises and falls into the except branch —
                # confirm the target interpreter before porting.
                print(str(words_weights).decode('unicode-escape'))
            except Exception:
                # FIX: was a bare `except:`, which also swallowed
                # KeyboardInterrupt/SystemExit; best-effort echo is kept.
                print(l)