import os

import tensorflow as tf  # TensorFlow 1.x API (tf.Session, tf.summary.FileWriter, ...)

# Project-level helpers (MEDIA_ROOT, CHECKPOINTS, make_dir_string, process,
# process_text, batch_iter, sequence) are defined elsewhere in the project.


# Method on the model class (called below as model.evaluate).
def evaluate(self, sess, x, y):
    """Average loss/accuracy over all batches of (x, y), with dropout disabled."""
    batch_test = batch_iter(x, y, self.pm.batch_size)
    total_loss, total_acc, num_batches = 0.0, 0.0, 0
    for x_batch, y_batch in batch_test:
        seq_len = sequence(x_batch)
        feed_dict = self.feed_data(x_batch, y_batch, seq_len, 1.0)  # keep_prob=1.0
        loss, accuracy = sess.run([self.loss, self.accuracy], feed_dict=feed_dict)
        total_loss += loss
        total_acc += accuracy
        num_batches += 1
    # The original returned only the final batch's metrics; averaging is more
    # representative of the whole evaluation set.
    return total_loss / num_batches, total_acc / num_batches
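# batch_iter() and sequence() are used throughout but not shown in this
# excerpt. A minimal sketch of what they are assumed to do (shuffle and yield
# fixed-size batches; measure the un-padded length of each row, assuming id 0
# is the padding token) -- the project's real implementations may differ:
import numpy as np


def batch_iter(x, y, batch_size=64):
    # Shuffle the data, then yield (x_batch, y_batch) slices of batch_size.
    x, y = np.asarray(x), np.asarray(y)
    indices = np.random.permutation(len(x))
    x, y = x[indices], y[indices]
    num_batches = int((len(x) - 1) / batch_size) + 1
    for i in range(num_batches):
        start, end = i * batch_size, min((i + 1) * batch_size, len(x))
        yield x[start:end], y[start:end]


def sequence(x_batch):
    # Actual sequence length per row: count of non-zero (non-padding) ids.
    return [int(np.sum(np.sign(row))) for row in x_batch]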
def train(model, pm, wordid, cat_to_id, dataid):
    tensorboard_dir = os.path.join(MEDIA_ROOT, 'tensorboard', 'text_rnn',
                                   make_dir_string(dataid, pm))
    save_dir = os.path.join(CHECKPOINTS, 'text_rnn', make_dir_string(dataid, pm))
    if not os.path.exists(tensorboard_dir):
        os.makedirs(tensorboard_dir)
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    save_path = os.path.join(save_dir, 'best_validation')

    # Scalar summaries for TensorBoard.
    tf.summary.scalar('loss', model.loss)
    tf.summary.scalar('accuracy', model.accuracy)
    merged_summary = tf.summary.merge_all()
    writer = tf.summary.FileWriter(tensorboard_dir)
    saver = tf.train.Saver()
    session = tf.Session()
    session.run(tf.global_variables_initializer())
    writer.add_graph(session.graph)

    x_train, y_train = process(pm.train_filename, wordid, cat_to_id, max_length=250)
    x_test, y_test = process(pm.test_filename, wordid, cat_to_id, max_length=250)

    for epoch in range(pm.num_epochs):
        print('Epoch:', epoch + 1)
        num_batches = int((len(x_train) - 1) / pm.batch_size) + 1
        batch_train = batch_iter(x_train, y_train, batch_size=pm.batch_size)
        for x_batch, y_batch in batch_train:
            seq_len = sequence(x_batch)
            feed_dict = model.feed_data(x_batch, y_batch, seq_len, pm.keep_prob)
            _, global_step, summary, train_loss, train_accuracy = session.run(
                [model.optimizer, model.global_step, merged_summary,
                 model.loss, model.accuracy],
                feed_dict=feed_dict)
            # The summary was computed but never written in the original.
            writer.add_summary(summary, global_step)
            if global_step % 100 == 0:
                test_loss, test_accuracy = model.evaluate(session, x_test, y_test)
                print('global_step:', global_step, 'train_loss:', train_loss,
                      'train_accuracy:', train_accuracy,
                      'test_loss:', test_loss, 'test_accuracy:', test_accuracy)
            if global_step % num_batches == 0:
                print('Saving Model...')
                saver.save(session, save_path, global_step=global_step)
        # Decay the learning rate once per epoch (this only affects training if
        # the model re-reads pm.learning_rate, e.g. through a placeholder).
        pm.learning_rate *= pm.lr_decay
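# A hedged usage sketch: TextRnn and pm below stand in for the project's model
# class and hyper-parameter object, which are not shown in this excerpt.
#
#     tf.reset_default_graph()
#     model = TextRnn()                          # builds loss/accuracy/optimizer ops
#     train(model, pm, wordid, cat_to_id, dataid=1)
#
# Training writes event files under tensorboard_dir; inspect them with:
#     tensorboard --logdir <tensorboard_dir>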
def val_text(model, text_data, pm, wordid, cat_to_id, data_id):
    """Predict the category index for a single piece of raw text."""
    session = tf.Session()
    session.run(tf.global_variables_initializer())
    save_path = tf.train.latest_checkpoint(
        os.path.join(CHECKPOINTS, 'text_rnn', make_dir_string(data_id, pm)))
    # latest_checkpoint() returns None when no checkpoint exists; guard before
    # restoring (replaces an unused os.path.exists() check in the original).
    if save_path is None:
        raise FileNotFoundError('No checkpoint found for this data_id/config.')
    saver = tf.train.Saver()
    saver.restore(sess=session, save_path=save_path)

    val_x = process_text(text_data, wordid, cat_to_id, max_length=250)
    seq_len = sequence(val_x)
    pre_lab = session.run(model.predict,
                          feed_dict={model.input_x: val_x,
                                     model.seq_length: seq_len,
                                     model.keep_prob: 1.0})
    # Return the prediction for the single input text.
    return pre_lab[0]
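# Example use of val_text(): classify one document and map the predicted index
# back to a category name. The id_to_cat inversion assumes cat_to_id has the
# shape {category_name: index}, which is how it is used above.
#
#     pred_id = val_text(model, 'some raw text to classify', pm, wordid,
#                        cat_to_id, data_id=1)
#     id_to_cat = {v: k for k, v in cat_to_id.items()}
#     print('predicted category:', id_to_cat[pred_id])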
def val(model, pm, wordid, cat_to_id, data_id):
    """Run the validation set through the model; return predictions and labels."""
    pre_label = []  # predicted class indices
    label = []      # ground-truth labels (one-hot rows, as produced by process())
    session = tf.Session()
    session.run(tf.global_variables_initializer())
    save_path = tf.train.latest_checkpoint(
        os.path.join(CHECKPOINTS, 'text_rnn', make_dir_string(data_id, pm)))
    saver = tf.train.Saver()
    saver.restore(sess=session, save_path=save_path)

    val_x, val_y = process(pm.val_filename, wordid, cat_to_id, max_length=250)
    batch_val = batch_iter(val_x, val_y, batch_size=64)
    for x_batch, y_batch in batch_val:
        seq_len = sequence(x_batch)
        pre_lab = session.run(model.predict,
                              feed_dict={model.input_x: x_batch,
                                         model.seq_length: seq_len,
                                         model.keep_prob: 1.0})
        pre_label.extend(pre_lab)
        label.extend(y_batch)
    return pre_label, label
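# val() returns predicted class indices alongside one-hot ground-truth rows,
# so the labels need an argmax before comparison. A minimal reporting sketch
# (using sklearn is an assumption; the project may report metrics differently):
from sklearn import metrics


def report_validation(model, pm, wordid, cat_to_id, data_id):
    pre_label, label = val(model, pm, wordid, cat_to_id, data_id)
    y_true = np.argmax(np.asarray(label), axis=1)  # one-hot -> class index
    y_pred = np.asarray(pre_label)
    print('val accuracy:', metrics.accuracy_score(y_true, y_pred))
    print(metrics.classification_report(y_true, y_pred))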