import os

import numpy as np
import tensorflow as tf

import input_data
import model


def run_training():
    # You need to change the directories to yours.
    # train_dir = '/home/kevin/tensorflow/cats_vs_dogs/data/train/'
    train_dir = 'D:/tensorflow/mydata/cat_dog2/'  # My dir--20170727-csq
    # logs_train_dir = '/home/kevin/tensorflow/cats_vs_dogs/logs/train/'
    logs_train_dir = 'D:/tensorflow/mylog/cat_dog2/'

    # Input pipeline: file lists -> shuffled training batches.
    train, train_label = input_data.get_files(train_dir)
    print(train)
    print(train_label)
    train_batch, train_label_batch = input_data.get_batch(
        train, train_label, IMG_W, IMG_H, BATCH_SIZE, CAPACITY)
    print(train_batch)
    print(train_label_batch)

    # Model graph: logits, loss, optimizer step and accuracy.
    train_logits = model.inference(train_batch, BATCH_SIZE, N_CLASSES)
    train_loss = model.losses(train_logits, train_label_batch)
    train_op = model.trainning(train_loss, learning_rate)
    train__acc = model.evaluation(train_logits, train_label_batch)

    summary_op = tf.summary.merge_all()
    print(summary_op)

    with tf.Session() as sess:
        train_writer = tf.summary.FileWriter(logs_train_dir, sess.graph)
        saver = tf.train.Saver()

        sess.run(tf.global_variables_initializer())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        try:
            for step in np.arange(MAX_STEP):
                print(step)
                if coord.should_stop():
                    break
                _, tra_loss, tra_acc = sess.run(
                    [train_op, train_loss, train__acc])

                # Write a TensorBoard summary every cnt_summary steps.
                if step % cnt_summary == 0:
                    print('Step %d, train loss = %.2f, train accuracy = %.2f%%'
                          % (step, tra_loss, tra_acc * 100.0))
                    summary_str = sess.run(summary_op)
                    train_writer.add_summary(summary_str, step)

                # Save a checkpoint every cnt_cache steps and at the last step.
                if step % cnt_cache == 0 or (step + 1) == MAX_STEP:
                    checkpoint_path = os.path.join(logs_train_dir, 'model.ckpt')
                    saver.save(sess, checkpoint_path, global_step=step)

        except tf.errors.OutOfRangeError:
            print('Done training -- epoch limit reached')
        finally:
            coord.request_stop()

        coord.join(threads)
        sess.close()
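# run_training() references module-level hyperparameters (N_CLASSES, IMG_W, IMG_H,
# BATCH_SIZE, CAPACITY, MAX_STEP, learning_rate, cnt_summary, cnt_cache) that the
# snippet above does not define. A minimal sketch of those definitions plus an
# entry point follows; the values mirror the second script below, and the two
# interval constants are illustrative assumptions, not taken from the original.
N_CLASSES = 2           # cats vs. dogs
IMG_W = 208             # image width after preprocessing
IMG_H = 208             # image height after preprocessing
BATCH_SIZE = 16
CAPACITY = 2000         # capacity of the input queue used by input_data.get_batch
MAX_STEP = 10000
learning_rate = 0.0001
cnt_summary = 50        # assumed: write a TensorBoard summary every 50 steps
cnt_cache = 2000        # assumed: save a checkpoint every 2000 steps

if __name__ == '__main__':
    run_training()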
import tensorflow as tf

import input_data1
import model1

N_CLASSES = 2
IMG_W = 208
IMG_H = 208
BATCH_SIZE = 16
CAPACITY = 2000
MAX_STEP = 10000
learning_rate = 0.0001

train_dir = r'E:\Jupyter\catanddog\ALLPetImages'
# train_dir = r'E:\PyCharmProject\mycatvsdog\PetImages'
logs_train_dir = r'E:\PyCharmProject\mycatvsdog\log'

# Input pipeline based on tf.data: file lists -> batched Dataset -> iterator.
train, train_label = input_data1.get_file(train_dir)
dataset = input_data1.get_batch(train, train_label, BATCH_SIZE)
# iterator = dataset.make_initializable_iterator()
iterator = dataset.make_one_shot_iterator()  # TF 1.x graph-mode iterator
next_element = iterator.get_next()

# Placeholders that each fetched batch is fed into at training time.
train_batch = tf.placeholder(tf.float32, shape=[BATCH_SIZE, IMG_H, IMG_W, 3])
train_label_batch = tf.placeholder(tf.int32, shape=[BATCH_SIZE])

# Model graph: logits, loss, optimizer step and accuracy.
train_logits = model1.mynn_inference(train_batch, BATCH_SIZE, N_CLASSES)
train_loss = model1.losses(train_logits, train_label_batch)
train_op = model1.training(train_loss, learning_rate)
train_acc = model1.evaluation(train_logits, train_label_batch)

summary_op = tf.compat.v1.summary.merge_all()  # merged summaries (drawn as line charts in TensorBoard)
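# The script above stops after building the graph; it still needs a session loop
# that pulls batches from next_element and feeds them into the two placeholders.
# The loop below is a sketch under TF 1.x graph mode, not the author's original
# code: it assumes the Dataset yields (image_batch, label_batch) pairs, and the
# summary/checkpoint intervals (50 and 2000 steps) are assumed values.
import os

with tf.compat.v1.Session() as sess:
    train_writer = tf.compat.v1.summary.FileWriter(logs_train_dir, sess.graph)
    saver = tf.compat.v1.train.Saver()
    sess.run(tf.compat.v1.global_variables_initializer())

    try:
        for step in range(MAX_STEP):
            # Fetch the next batch produced by the tf.data pipeline.
            images, labels = sess.run(next_element)
            feed = {train_batch: images, train_label_batch: labels}

            _, tra_loss, tra_acc = sess.run(
                [train_op, train_loss, train_acc], feed_dict=feed)

            if step % 50 == 0:
                print('Step %d, train loss = %.2f, train accuracy = %.2f%%'
                      % (step, tra_loss, tra_acc * 100.0))
                summary_str = sess.run(summary_op, feed_dict=feed)
                train_writer.add_summary(summary_str, step)

            if step % 2000 == 0 or (step + 1) == MAX_STEP:
                saver.save(sess, os.path.join(logs_train_dir, 'model.ckpt'),
                           global_step=step)
    except tf.errors.OutOfRangeError:
        # Raised when the one-shot iterator runs out of data.
        print('Done training -- dataset exhausted')

    train_writer.close()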