def evaluate(input_x, input_y):
    '''
    Evaluate the text-classification model on one batch.

    Rebuilds the inference graph in a fresh ``tf.Graph``, restores the latest
    checkpoint recorded in ``FLAGS.checkpoint_dir`` and runs it on the batch.

    :param input_x: feature batch, fed to a ``tf.int32`` placeholder of shape
        ``[None, cnnc.SEQUENCE_LENGTH]``.
    :param input_y: label batch, fed to a ``tf.int32`` placeholder of shape
        ``[None, cnnc.FLAGS.num_class]``. The true class is taken as the
        argmax along axis 1 (presumably one-hot labels — confirm with caller).
    :return: ``(result, accuracy)`` — ``result`` is, per example, the index of
        the largest logit (the predicted class); ``accuracy`` is the fraction
        of examples whose prediction equals the argmax of ``input_y``.
    :raises IOError: if no checkpoint is found in ``FLAGS.checkpoint_dir``
        (previously this fell through and failed on uninitialised variables).
    '''
    graph = tf.Graph()
    with graph.as_default(), tf.Session(graph=graph) as sess:
        # Rebuild the model graph so the saved variables can be restored into it.
        features = tf.placeholder(tf.int32, [None, cnnc.SEQUENCE_LENGTH])
        labels = tf.placeholder(tf.int32, [None, cnnc.FLAGS.num_class])
        logits = cnnc.inference(features)
        predictions = tf.arg_max(logits, 1)
        correct_predictions = tf.equal(predictions, tf.arg_max(labels, 1))
        accuracy_op = tf.reduce_mean(tf.cast(correct_predictions, dtype=tf.float32))

        saver = tf.train.Saver()
        ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
        if not (ckpt and ckpt.model_checkpoint_path):
            print("No checkpoint file found")
            # Without restored weights the variables are uninitialised and
            # sess.run below would fail with an opaque TensorFlow error.
            raise IOError("No checkpoint file found in %s" % FLAGS.checkpoint_dir)
        saver.restore(sess, ckpt.model_checkpoint_path)
        print("SUCESS")

        result, accuracy = sess.run([predictions, accuracy_op],
                                    feed_dict={features: input_x, labels: input_y})
    return result, accuracy
def train():
    '''
    Train the cnnc model.

    Builds the training graph on top of ``cnnc.distorted_inputs(FLAGS.train_dir)``
    and runs ``FLAGS.max_steps`` training steps. It prints throughput every 10
    steps and writes merged summaries every 100 steps. It saves a checkpoint to
    ``FLAGS.train_dir/model.ckpt`` every 1000 steps and after the final step.

    The session, summary writer and queue-runner threads are always shut
    down, even when training aborts with an exception.

    :raises RuntimeError: if the loss becomes NaN (the model diverged).
    '''
    with tf.Graph().as_default():
        global_step = tf.Variable(0, trainable=False)
        # Input pipeline over the directory that holds the training data.
        features, labels = cnnc.distorted_inputs(FLAGS.train_dir)
        logits = cnnc.inference(features)
        loss = cnnc.loss(logits, labels)
        train_op = cnnc.train(loss, global_step)

        saver = tf.train.Saver(tf.all_variables())
        # Merged summary op; None when the graph defines no summaries.
        summary_op = tf.merge_all_summaries()
        init = tf.initialize_all_variables()

        sess = tf.Session(config=tf.ConfigProto(
            log_device_placement=FLAGS.log_device_placement))
        # The coordinator lets us stop and join the queue-runner threads on exit.
        coord = tf.train.Coordinator()
        threads = []
        summary_writer = None
        try:
            sess.run(init)
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            summary_writer = tf.train.SummaryWriter(FLAGS.train_dir,
                                                    graph_def=sess.graph_def)
            for step in range(FLAGS.max_steps):
                start_time = time.time()
                _, loss_value = sess.run([train_op, loss])
                duration = time.time() - start_time

                # Abort on divergence. An explicit raise (instead of assert)
                # keeps this check active under ``python -O``.
                if np.isnan(loss_value):
                    raise RuntimeError("Model diverged with loss = NAN")

                if step % 10 == 0:
                    num_examples_per_step = FLAGS.batch_size
                    # Guard against a zero timer delta on very fast steps.
                    examples_per_sec = (num_examples_per_step / duration
                                        if duration > 0 else float('inf'))
                    sec_per_batch = float(duration)
                    format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                                  'sec/batch)')
                    print(format_str % (datetime.now(), step, loss_value,
                                        examples_per_sec, sec_per_batch))

                if step % 100 == 0 and summary_op is not None:
                    summary_str = sess.run(summary_op)
                    summary_writer.add_summary(summary_str, step)

                if step % 1000 == 0 or (step + 1) == FLAGS.max_steps:
                    checkpoint_path = os.path.join(FLAGS.train_dir, "model.ckpt")
                    saver.save(sess, checkpoint_path, global_step=step)
        finally:
            # Stop the input threads, flush summaries and release the session.
            coord.request_stop()
            coord.join(threads)
            if summary_writer is not None:
                summary_writer.close()
            sess.close()