コード例 #1
0
ファイル: train_classifier.py プロジェクト: JZDBB/tensorflow
                   'summary_train_op', 'summary_test_op', 'summary_epoch_train_op']
    tensors = [loss, accuracy, train_op, global_step, image_place, label_place, dropout_param, summary_train_op,
               summary_test_op, summary_epoch_train_op]
    tensors_dictionary = dict(zip(tensors_key, tensors))

    ############################################
    ############ Run the Session ###############
    ############################################
    session_conf = tf.ConfigProto(
        allow_soft_placement=FLAGS.allow_soft_placement,
        log_device_placement=FLAGS.log_device_placement)
    sess = tf.Session(graph=graph, config=session_conf)

    with sess.as_default():
        saver = tf.train.Saver(max_to_keep=FLAGS.max_num_checkpoint)

        sess.run(tf.global_variables_initializer())

        ###################################################
        ############ Training / Evaluation ###############
        ###################################################
        train_evaluation.train(sess=sess, saver=saver, tensors=tensors_dictionary, data=data,
                               train_dir=FLAGS.train_dir,
                               finetuning=FLAGS.fine_tuning, online_test=FLAGS.online_test,
                               num_epochs=FLAGS.num_epochs, checkpoint_dir=FLAGS.checkpoint_dir,
                               batch_size=FLAGS.batch_size)

        # Test in the end of experiment.
        train_evaluation.evaluation(sess=sess, saver=saver, tensors=tensors_dictionary, data=data,
                                    checkpoint_dir=FLAGS.checkpoint_dir)
コード例 #2
0
    tensors_dictionary = dict(zip(tensors_key, tensors))

    ############################################
    ############ Run the Session ###############
    ############################################
    session_conf = tf.ConfigProto(
        allow_soft_placement=FLAGS.allow_soft_placement,
        log_device_placement=FLAGS.log_device_placement)
    sess = tf.Session(graph=graph, config=session_conf)

    with sess.as_default():
        # Run the saver.
        # 'max_to_keep' flag determines the maximum number of models that the tensorflow save and keep. default by TensorFlow = 5.
        saver = tf.train.Saver(max_to_keep=FLAGS.max_num_checkpoint)

        # Initialize all variables
        sess.run(tf.global_variables_initializer())

        ###################################################
        ############ Training / Evaluation ###############
        ###################################################
        train_evaluation.train(sess=sess, saver=saver, tensors=tensors_dictionary, data=data,
                               train_dir=FLAGS.train_dir,
                               finetuning=FLAGS.fine_tuning, online_test=FLAGS.online_test,
                               num_epochs=FLAGS.num_epochs, checkpoint_dir=FLAGS.checkpoint_dir,
                               batch_size=FLAGS.batch_size)

        # Test in the end of experiment.
        train_evaluation.evaluation(sess=sess, saver=saver, tensors=tensors_dictionary, data=data,
                                    checkpoint_dir=FLAGS.checkpoint_dir)