def evaluate(input_x, input_y):
    '''
    Evaluate the text classifier.
    :return
        result: predicted class indices (argmax over the logits)
        accuracy: accuracy on the given batch
    '''
    graph = tf.Graph()
    with graph.as_default(), tf.Session() as sess:
        # Rebuild the inference graph; the trained weights are restored below
        features = tf.placeholder(tf.int32, [None, cnnc.SEQUENCE_LENGTH])
        labels = tf.placeholder(tf.int32, [None, cnnc.FLAGS.num_class])
        logits = cnnc.inference(features)
        predictions = tf.argmax(logits, 1)
        correct_predictions = tf.equal(predictions, tf.argmax(labels, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_predictions,
                                          dtype=tf.float32))
        saver = tf.train.Saver()
        ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
            print("Successfully restored model from %s" % ckpt.model_checkpoint_path)
        else:
            print("No checkpoint file found")
            return None, None

        result, accuracy = sess.run([predictions, accuracy], feed_dict={features: input_x, labels: input_y})

    return result, accuracy
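
# Usage sketch for evaluate() (hypothetical names; the original code does not
# show how the eval batch is built). Assumes eval_x is a numpy array of padded
# word-id sequences shaped [batch, cnnc.SEQUENCE_LENGTH] and eval_y is a
# one-hot label array shaped [batch, cnnc.FLAGS.num_class]:
#
#     preds, acc = evaluate(eval_x, eval_y)
#     print("eval accuracy: %.4f" % acc)
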
def train():
    '''
    Train the cnnc model.
    '''
    with tf.Graph().as_default():
        global_step = tf.Variable(0, trainable=False)

        # Read training examples from the directory that holds the data
        features, labels = cnnc.distorted_inputs(FLAGS.train_dir)

        logits = cnnc.inference(features)

        loss = cnnc.loss(logits, labels)

        train_op = cnnc.train(loss, global_step)

        # Saver for writing checkpoints of all variables
        saver = tf.train.Saver(tf.all_variables())

        # Build the op that merges all summaries
        summary_op = tf.merge_all_summaries()

        # Variable initialization op
        init = tf.initialize_all_variables()

        sess = tf.Session(config=tf.ConfigProto(log_device_placement=FLAGS.log_device_placement))

        sess.run(init)

        # Start the queue runners that feed the input pipeline
        tf.train.start_queue_runners(sess=sess)

        summary_writer = tf.train.SummaryWriter(FLAGS.train_dir, graph_def=sess.graph_def)

        for step in range(FLAGS.max_steps):
            start_time = time.time()
            _, loss_value = sess.run([train_op, loss])
            duration = time.time() - start_time

            # Abort if the loss has diverged
            assert not np.isnan(loss_value), "Model diverged with loss = NaN"

            if step % 10 == 0:
                num_examples_per_step = FLAGS.batch_size
                examples_per_sec = num_examples_per_step / duration
                sec_per_batch = float(duration)
                format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                              'sec/batch)')
                print(format_str % (datetime.now(), step, loss_value,
                                    examples_per_sec, sec_per_batch))

            # Periodically write summaries for TensorBoard
            if step % 100 == 0:
                summary_str = sess.run(summary_op)
                summary_writer.add_summary(summary_str, step)

            # Periodically save a checkpoint, and always save at the last step
            if step % 1000 == 0 or (step + 1) == FLAGS.max_steps:
                checkpoint_path = os.path.join(FLAGS.train_dir, "model.ckpt")
                saver.save(sess, checkpoint_path, global_step=step)
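

def main(argv=None):
    # Entry-point sketch (an assumption; this snippet does not define one):
    # run training when the script is executed directly.
    train()


if __name__ == "__main__":
    tf.app.run()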