def test():
    """Smoke-test the settings, MNIST loading, placeholders and feed dict.

    Prints a label before each stage, then performs that stage: print the
    learning rate, load MNIST, create the image/label placeholders, and fill
    a feed dict from the training split. The feed dict built at the end is
    discarded, so the function returns None.

    NOTE(review): a second ``def test()`` appears later in this module and
    rebinds the name, so this definition is unreachable by name once the
    module has loaded.
    """
    print("load settings:")
    print(FLAGS.learning_rate)

    print("load dataset:")
    dataset = logistic_regression.load_data_mnist()

    print("place holder")
    batch_size = FLAGS.batch_size
    placeholders = ph.placeholder_inputs(
        batch_size, model.IMAGE_PIXELS, model.NUM_CLASSES)
    image_ph, label_ph = placeholders

    print("fill feed dict")
    fd.fill_feed_dict(
        dataset.train, image_ph, label_ph, batch_size, FLAGS.fake_data)
def test():
    """Smoke-test the settings, MNIST loading, placeholders and feed dict.

    Each stage is announced on stdout before it runs:

    1. print the configured learning rate;
    2. load the MNIST data sets;
    3. build the image/label placeholders for one batch;
    4. fill a feed dict from the training split.

    The resulting feed dict is not used; the function returns None.
    """
    print("load settings:")
    print(FLAGS.learning_rate)

    print("load dataset:")
    dataset = logistic_regression.load_data_mnist()

    print("place holder")
    batch_size = FLAGS.batch_size
    image_ph, label_ph = ph.placeholder_inputs(batch_size,
                                               model.IMAGE_PIXELS,
                                               model.NUM_CLASSES)

    print("fill feed dict")
    training_split = dataset.train
    fd.fill_feed_dict(training_split, image_ph, label_ph, batch_size,
                      FLAGS.fake_data)
def run_training():
    """Train the MNIST model configured by FLAGS, with summaries and checkpoints.

    Loads MNIST via ``logistic_regression.load_data_mnist(one_hot=False)`` and
    builds the graph from ``model``:

    * placeholders;
    * inference, with hidden layer sizes FLAGS.hidden1 and FLAGS.hidden2;
    * loss, train op and evaluation op.

    It then runs FLAGS.max_steps training steps.

    * Every FLAGS.summary_step steps, it prints the loss and step duration
      and writes a summary to FLAGS.model_dir.
    * Every FLAGS.checkpoint_step steps, and on the final step, it saves a
      checkpoint and calls ``do_eval`` on the train, validation and test
      splits.

    The session and the summary writer are closed when training finishes,
    and also when a step raises.

    NOTE(review): a second ``def run_training()`` appears later in this
    module and rebinds the name, so this definition is unreachable by name
    once the module has loaded.
    """
    data_sets = logistic_regression.load_data_mnist(one_hot=False)

    hiddens = [FLAGS.hidden1, FLAGS.hidden2]

    # Build all ops in a fresh graph, which is the default inside this block.
    with tf.Graph().as_default():
        # Input placeholders for one batch.
        images_placeholder, labels_placeholder = ph.placeholder_inputs(
            FLAGS.batch_size, model.IMAGE_PIXELS, model.NUM_CLASSES)

        # Forward pass, loss, training op and evaluation op.
        logits = model.inference(images_placeholder, hiddens)
        loss = model.loss(logits, labels_placeholder)
        train_op = model.train(loss, FLAGS.learning_rate)
        eval_correct = model.evaluation(logits, labels_placeholder)

        # Merge the summaries registered while building the ops above.
        summary_op = tf.merge_all_summaries()

        # Saver used to write checkpoints.
        saver = tf.train.Saver()

        sess = tf.Session()
        summary_writer = None
        try:
            # Initialize all variables.
            init = tf.initialize_all_variables()
            sess.run(init)

            # Writer for summaries (and the graph definition).
            summary_writer = tf.train.SummaryWriter(
                FLAGS.model_dir, graph_def=sess.graph_def)

            for step in xrange(FLAGS.max_steps):
                start_time = time.time()

                # Fetch the training batch for this step.
                feed_dict = fd.fill_feed_dict(
                    data_sets.train, images_placeholder, labels_placeholder,
                    FLAGS.batch_size, FLAGS.fake_data)

                # One optimization step.
                _, loss_value = sess.run([train_op, loss],
                                         feed_dict=feed_dict)

                duration = time.time() - start_time

                # Every FLAGS.summary_step steps: log to stdout and write
                # summaries.
                if step % FLAGS.summary_step == 0:
                    print('Step %d: loss = %.2f , %.3f sec' %
                          (step, loss_value, duration))
                    summary_str = sess.run(summary_op, feed_dict=feed_dict)
                    summary_writer.add_summary(summary_str, step)

                # Every FLAGS.checkpoint_step steps, and on the last step:
                # save a checkpoint, then evaluate on every split.
                if ((step + 1) % FLAGS.checkpoint_step == 0
                        or (step + 1) == FLAGS.max_steps):
                    # NOTE(review): Saver.save treats its path as a filename
                    # prefix. Checkpoints are therefore named
                    # "<model_dir>-<step>" rather than placed inside
                    # model_dir, unless model_dir ends with a path separator.
                    # Confirm this is intended.
                    saver.save(sess, FLAGS.model_dir, global_step=step)

                    print('CheckPoint Train Evalation:')
                    do_eval(sess, eval_correct, images_placeholder,
                            labels_placeholder, data_sets.train)
                    print('CheckPoint Valid Evalation:')
                    do_eval(sess, eval_correct, images_placeholder,
                            labels_placeholder, data_sets.validation)
                    print('CheckPoint Test Evalation:')
                    do_eval(sess, eval_correct, images_placeholder,
                            labels_placeholder, data_sets.test)
        finally:
            # Release the event file and session resources even on error.
            if summary_writer is not None:
                summary_writer.close()
            sess.close()
def run_training():
    """Train the MNIST model configured by FLAGS, with summaries and checkpoints.

    Loads MNIST via ``logistic_regression.load_data_mnist(one_hot=False)`` and
    builds the graph from ``model``:

    * placeholders;
    * inference, with hidden layer sizes FLAGS.hidden1 and FLAGS.hidden2;
    * loss, train op and evaluation op.

    It then runs FLAGS.max_steps training steps.

    * Every FLAGS.summary_step steps, it prints the loss and step duration
      and writes a summary to FLAGS.model_dir.
    * Every FLAGS.checkpoint_step steps, and on the final step, it saves a
      checkpoint and calls ``do_eval`` on the train, validation and test
      splits.

    The session and the summary writer are closed when training finishes,
    and also when a step raises.
    """
    data_sets = logistic_regression.load_data_mnist(one_hot=False)

    hiddens = [FLAGS.hidden1, FLAGS.hidden2]

    # Build all ops in a fresh graph, which is the default inside this block.
    with tf.Graph().as_default():
        # Input placeholders for one batch.
        images_placeholder, labels_placeholder = ph.placeholder_inputs(
            FLAGS.batch_size, model.IMAGE_PIXELS, model.NUM_CLASSES)

        # Forward pass, loss, training op and evaluation op.
        logits = model.inference(images_placeholder, hiddens)
        loss = model.loss(logits, labels_placeholder)
        train_op = model.train(loss, FLAGS.learning_rate)
        eval_correct = model.evaluation(logits, labels_placeholder)

        # Merge the summaries registered while building the ops above.
        summary_op = tf.merge_all_summaries()

        # Saver used to write checkpoints.
        saver = tf.train.Saver()

        sess = tf.Session()
        summary_writer = None
        try:
            # Initialize all variables.
            init = tf.initialize_all_variables()
            sess.run(init)

            # Writer for summaries (and the graph definition).
            summary_writer = tf.train.SummaryWriter(FLAGS.model_dir,
                                                    graph_def=sess.graph_def)

            for step in xrange(FLAGS.max_steps):
                start_time = time.time()

                # Fetch the training batch for this step.
                feed_dict = fd.fill_feed_dict(data_sets.train,
                                              images_placeholder,
                                              labels_placeholder,
                                              FLAGS.batch_size,
                                              FLAGS.fake_data)

                # One optimization step.
                _, loss_value = sess.run([train_op, loss],
                                         feed_dict=feed_dict)

                duration = time.time() - start_time

                # Every FLAGS.summary_step steps: log to stdout and write
                # summaries.
                if step % FLAGS.summary_step == 0:
                    print('Step %d: loss = %.2f , %.3f sec' %
                          (step, loss_value, duration))
                    summary_str = sess.run(summary_op, feed_dict=feed_dict)
                    summary_writer.add_summary(summary_str, step)

                # Every FLAGS.checkpoint_step steps, and on the last step:
                # save a checkpoint, then evaluate on every split.
                if ((step + 1) % FLAGS.checkpoint_step == 0
                        or (step + 1) == FLAGS.max_steps):
                    # NOTE(review): Saver.save treats its path as a filename
                    # prefix. Checkpoints are therefore named
                    # "<model_dir>-<step>" rather than placed inside
                    # model_dir, unless model_dir ends with a path separator.
                    # Confirm this is intended.
                    saver.save(sess, FLAGS.model_dir, global_step=step)

                    print('CheckPoint Train Evalation:')
                    do_eval(sess, eval_correct, images_placeholder,
                            labels_placeholder, data_sets.train)
                    print('CheckPoint Valid Evalation:')
                    do_eval(sess, eval_correct, images_placeholder,
                            labels_placeholder, data_sets.validation)
                    print('CheckPoint Test Evalation:')
                    do_eval(sess, eval_correct, images_placeholder,
                            labels_placeholder, data_sets.test)
        finally:
            # Release the event file and session resources even on error.
            if summary_writer is not None:
                summary_writer.close()
            sess.close()