def main():
    args = parser.parse_args()
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    ######################
    # directory preparation
    filewriter_path = args.tensorboard_dir
    checkpoint_path = args.checkpoint_dir

    test_mkdir(filewriter_path)
    test_mkdir(checkpoint_path)

    ######################
    # data preparation
    train_file = os.path.join(args.list_dir, "train.txt")
    val_file = os.path.join(args.list_dir, "val.txt")

    train_generator = ImageDataGenerator(train_file, shuffle=True)
    val_generator = ImageDataGenerator(val_file, shuffle=False)

    batch_size = args.batch_size
    # (data_size is assumed here to already count batches per epoch)
    train_batches_per_epoch = train_generator.data_size
    val_batches_per_epoch = val_generator.data_size

    ######################
    # model graph preparation
    patch_height = args.patch_size
    patch_width = args.patch_size

    # TF placeholders for graph input
    leftx = tf.placeholder(tf.float32,
                           shape=[batch_size, patch_height, patch_width, 1])
    rightx_pos = tf.placeholder(tf.float32,
                                shape=[batch_size, patch_height, patch_width, 1])
    rightx_neg = tf.placeholder(tf.float32,
                                shape=[batch_size, patch_height, patch_width, 1])

    # Initialize the three weight-sharing branches of the model
    left_model = NET(leftx, input_patch_size=patch_height, batch_size=batch_size)
    right_model_pos = NET(rightx_pos, input_patch_size=patch_height, batch_size=batch_size)
    right_model_neg = NET(rightx_neg, input_patch_size=patch_height, batch_size=batch_size)

    featuresl = tf.squeeze(left_model.features, [1, 2])
    featuresr_pos = tf.squeeze(right_model_pos.features, [1, 2])
    featuresr_neg = tf.squeeze(right_model_neg.features, [1, 2])

    # Op for calculating cosine distance/dot product
    with tf.name_scope("correlation"):
        cosine_pos = tf.reduce_sum(tf.multiply(featuresl, featuresr_pos), axis=-1)
        cosine_neg = tf.reduce_sum(tf.multiply(featuresl, featuresr_neg), axis=-1)

    # Op for calculating the loss
    with tf.name_scope("hinge_loss"):
        margin = tf.ones(shape=[batch_size], dtype=tf.float32) * args.margin
        loss = tf.maximum(0.0, margin - cosine_pos + cosine_neg)
        loss = tf.reduce_mean(loss)

    # Train op
    with tf.name_scope("train"):
        var_list = tf.trainable_variables()
        for var in var_list:
            print("{}: {}".format(var.name, var.shape))

        # Get gradients of all trainable variables
        gradients = tf.gradients(loss, var_list)
        gradients = list(zip(gradients, var_list))

        # Create optimizer and apply gradient descent with momentum
        optimizer = tf.train.MomentumOptimizer(args.learning_rate, args.beta)
        train_op = optimizer.apply_gradients(grads_and_vars=gradients)

    # Summary ops for TensorBoard visualization
    with tf.name_scope("training_metric"):
        training_summary = []
        # Add loss to summary
        training_summary.append(tf.summary.scalar('hinge_loss', loss))
        # Merge all summaries together
        training_merged_summary = tf.summary.merge(training_summary)

    # validation loss
    with tf.name_scope("val_metric"):
        val_summary = []
        val_loss = tf.placeholder(tf.float32, [])
        # Add val loss to summary
        val_summary.append(tf.summary.scalar('val_hinge_loss', val_loss))
        val_merged_summary = tf.summary.merge(val_summary)

    # Initialize the FileWriter
    writer = tf.summary.FileWriter(filewriter_path)

    # Initialize a saver for storing model checkpoints
    saver = tf.train.Saver(max_to_keep=10)

    ######################
    # DO training
    # Start TensorFlow session
    with tf.Session(config=tf.ConfigProto(
            log_device_placement=False,
            allow_soft_placement=True,
            gpu_options=tf.GPUOptions(allow_growth=True))) as sess:

        # Initialize all variables
        sess.run(tf.global_variables_initializer())

        # Resume from a checkpoint, or start fresh
        if args.resume is None:
            # Add the model graph to TensorBoard before initial training
            writer.add_graph(sess.graph)
        else:
            saver.restore(sess, args.resume)

        print("training_batches_per_epoch: {}, val_batches_per_epoch: {}.".format(
            train_batches_per_epoch, val_batches_per_epoch))
        print("{} Start training...".format(datetime.now()))
        print("{} Open TensorBoard at --logdir {}".format(
            datetime.now(), filewriter_path))

        # Loop training
        for epoch in range(args.start_epoch, args.end_epoch):
            print("{} Epoch number: {}".format(datetime.now(), epoch + 1))

            for batch in tqdm(range(train_batches_per_epoch)):
                # Get a batch of data
                batch_left, batch_right_pos, batch_right_neg = \
                    train_generator.next_batch(batch_size)

                # And run the training op
                sess.run(train_op, feed_dict={
                    leftx: batch_left,
                    rightx_pos: batch_right_pos,
                    rightx_neg: batch_right_neg
                })

                # Generate summary with the current batch of data and write to file
                if (batch + 1) % args.print_freq == 0:
                    s = sess.run(training_merged_summary, feed_dict={
                        leftx: batch_left,
                        rightx_pos: batch_right_pos,
                        rightx_neg: batch_right_neg
                    })
                    writer.add_summary(s, epoch * train_batches_per_epoch + batch)

            if (epoch + 1) % args.save_freq == 0:
                print("{} Saving checkpoint of model...".format(datetime.now()))
                # Save a checkpoint of the model
                checkpoint_name = os.path.join(
                    checkpoint_path, 'model_epoch' + str(epoch + 1) + '.ckpt')
                save_path = saver.save(sess, checkpoint_name)

            if (epoch + 1) % args.val_freq == 0:
                # Validate the model on the entire validation set
                print("{} Start validation".format(datetime.now()))
                val_ls = 0.
                for _ in tqdm(range(val_batches_per_epoch)):
                    batch_left, batch_right_pos, batch_right_neg = \
                        val_generator.next_batch(batch_size)
                    result = sess.run(loss, feed_dict={
                        leftx: batch_left,
                        rightx_pos: batch_right_pos,
                        rightx_neg: batch_right_neg
                    })
                    val_ls += result

                val_ls = val_ls / (1. * val_batches_per_epoch)
                print('validation loss: {}'.format(val_ls))
                s = sess.run(val_merged_summary,
                             feed_dict={val_loss: np.float32(val_ls)})
                writer.add_summary(s, train_batches_per_epoch * (epoch + 1))

                # Reset the file pointer of the image data generator
                val_generator.reset_pointer()

            train_generator.reset_pointer()
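
# A minimal sketch of the module-level `parser` that main() reads from,
# assuming a plain argparse.ArgumentParser; the defaults below are
# illustrative assumptions, not values from the original source.
import argparse

parser = argparse.ArgumentParser(description='Train a siamese patch-matching network')
parser.add_argument('--gpu', default='0', help='value for CUDA_VISIBLE_DEVICES')
parser.add_argument('--tensorboard_dir', default='./tensorboard')
parser.add_argument('--checkpoint_dir', default='./checkpoints')
parser.add_argument('--list_dir', default='./lists', help='holds train.txt / val.txt')
parser.add_argument('--batch_size', type=int, default=128)
parser.add_argument('--patch_size', type=int, default=32)
parser.add_argument('--margin', type=float, default=0.2, help='hinge-loss margin')
parser.add_argument('--learning_rate', type=float, default=0.01)
parser.add_argument('--beta', type=float, default=0.9, help='momentum term')
parser.add_argument('--resume', default=None, help='checkpoint to restore, if any')
parser.add_argument('--start_epoch', type=int, default=0)
parser.add_argument('--end_epoch', type=int, default=100)
parser.add_argument('--print_freq', type=int, default=100)
parser.add_argument('--save_freq', type=int, default=1)
parser.add_argument('--val_freq', type=int, default=1)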
print("{} Start validation".format(datetime.now())) test_acc = 0. test_count = 0 for _ in range(val_batches_per_epoch): batch_tx, batch_ty = val_generator.next_batch(batch_size) acc = sess.run(accuracy, feed_dict={ x: batch_tx, y: batch_ty, keep_prob: 1. }) test_acc += acc test_count += 1 test_acc /= test_count print("{} Validation Accuracy = {:.4f}".format(datetime.now(), test_acc)) # Reset the file pointer of the image data generator val_generator.reset_pointer() train_generator.reset_pointer() print("{} Saving checkpoint of model...".format(datetime.now())) #save checkpoint of the model checkpoint_name = os.path.join( checkpoint_path, 'model_epoch' + str(epoch + 1) + '.ckpt') save_path = saver.save(sess, checkpoint_name) print("{} Model checkpoint saved at {}".format(datetime.now(), checkpoint_name))
def train(batch_size, learning_rate, conv, fc, dropout_rate, additional):
    x = tf.placeholder(tf.float32, [batch_size, input_size, input_size, 3])
    y = tf.placeholder(tf.float32, [None, num_classes])
    keep_prob = tf.placeholder(tf.float32)

    # If fewer convolutional layers are used, the first fully connected
    # layer's weights also need to be learned
    if conv < 5:
        train_layers.append('fc6')

    # Load the model with the desired parameters
    model = AlexNet(x, keep_prob, num_classes, train_layers, fc, conv, additional)
    score = model.fc8

    # List of trainable variables of the layers we want to train
    var_list = [v for v in tf.trainable_variables()
                if v.name.split('/')[0] in train_layers]

    # Op for calculating the loss
    with tf.name_scope('cross_entropy'):
        loss = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(logits=score, labels=y))

    # Train op
    with tf.name_scope('train'):
        gradients = tf.gradients(loss, var_list)
        gradients = list(zip(gradients, var_list))
        optimizer = tf.train.GradientDescentOptimizer(learning_rate)
        train_op = optimizer.apply_gradients(grads_and_vars=gradients)

    # Add gradients to summary
    for gradient, var in gradients:
        tf.summary.histogram(var.name + '/gradient', gradient)

    # Add the variables we train to the summary
    for var in var_list:
        tf.summary.histogram(var.name, var)

    # Add the loss to the summary
    tf.summary.scalar('cross_entropy', loss)

    # Op for the accuracy of the model
    with tf.name_scope('accuracy'):
        correct_pred = tf.equal(tf.argmax(score, 1), tf.argmax(y, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
    tf.summary.scalar('accuracy', accuracy)

    # Merge all summaries together
    merged_summary = tf.summary.merge_all()

    # Initialize the FileWriter
    writer = tf.summary.FileWriter(filewriter_path)
    saver = tf.train.Saver()

    train_generator = ImageDataGenerator(train_file, horizontal_flip=True, shuffle=True)
    val_generator = ImageDataGenerator(val_file, shuffle=False)

    # Get the number of training/validation steps per epoch
    train_batches_per_epoch = np.floor(train_generator.data_size / batch_size).astype(int)
    val_batches_per_epoch = np.floor(val_generator.data_size / batch_size).astype(int)

    # sess = tf.Session(config=tf.ConfigProto(log_device_placement=True))
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())

    # saver = tf.train.import_meta_graph('/tmp/finetune_alexnet/model_epoch0.ckpt.meta')
    # saver.restore(sess, '/tmp/finetune_alexnet/model_epoch0.ckpt')
    # Comment out the line below and uncomment the two above to resume from a
    # checkpoint for faster training
    model.load_initial_weights(sess)
    print("{} Restored model...".format(datetime.now()))

    test_acc_prev = 0
    test_count = 0
    for _ in range(val_batches_per_epoch):
        batch_tx, batch_ty = val_generator.next_batch(batch_size)
        acc = sess.run(accuracy, feed_dict={x: batch_tx, y: batch_ty, keep_prob: 1.})
        test_acc_prev += acc
        test_count += 1
    test_acc_prev /= test_count
    test_acc = test_acc_prev
    print("{} Initial Validation Accuracy = {:.4f}".format(datetime.now(), test_acc_prev))

    writer.add_graph(sess.graph)
    print('{} Start training...'.format(datetime.now()))
    print('{} Open TensorBoard at --logdir {}'.format(datetime.now(), filewriter_path))

    for epoch in range(num_epochs):
        print('{} Epoch number: {}'.format(datetime.now(), epoch + 1))
        step = 1

        while step < train_batches_per_epoch:
            print('{} Step number: {}'.format(datetime.now(), step))
            batch_xs, batch_ys = train_generator.next_batch(batch_size)
            sess.run(train_op, feed_dict={x: batch_xs, y: batch_ys, keep_prob: dropout_rate})

            if step % display_step == 0:
                s = sess.run(merged_summary,
                             feed_dict={x: batch_xs, y: batch_ys, keep_prob: dropout_rate})
                writer.add_summary(s, epoch * train_batches_per_epoch + step)
            step += 1

        # Validate the model on the entire validation set
        print("{} Start validation".format(datetime.now()))
        test_acc = 0.
        test_count = 0
        for ind in range(val_batches_per_epoch):
            print('{} Valid batch number: {}'.format(datetime.now(), ind))
            batch_tx, batch_ty = val_generator.next_batch(batch_size)
            acc = sess.run(accuracy, feed_dict={x: batch_tx, y: batch_ty, keep_prob: 1.})
            test_acc += acc
            test_count += 1
        test_acc /= test_count
        print("{} Validation Accuracy = {:.4f}".format(datetime.now(), test_acc))

        if test_acc > test_acc_prev:
            print("{} Saving checkpoint of model...".format(datetime.now()))
            # Save a checkpoint of the model
            checkpoint_name = os.path.join(checkpoint_path,
                                           'model_epoch' + str(epoch) + '.ckpt')
            save_path = saver.save(sess, checkpoint_name)
            print("{} Model checkpoint saved at {}".format(datetime.now(), checkpoint_name))

        if abs(test_acc - test_acc_prev) < 0.05:
            print("Early stopping.... exiting")
            break
        test_acc_prev = test_acc

        # Reset the file pointer of the image data generator
        val_generator.reset_pointer()
        train_generator.reset_pointer()

    return test_acc
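
# A usage sketch for train() above, assuming the module-level globals it
# relies on (`input_size`, `num_classes`, `train_layers`, `num_epochs`,
# `display_step`, `filewriter_path`, `checkpoint_path`, `train_file`,
# `val_file`) are already defined. The sweep values are illustrative.
if __name__ == '__main__':
    best_acc, best_lr = 0., None
    for lr in (1e-3, 1e-4):
        tf.reset_default_graph()  # start each run from a fresh graph
        acc = train(batch_size=128, learning_rate=lr, conv=5, fc=2,
                    dropout_rate=0.5, additional=False)
        if acc > best_acc:
            best_acc, best_lr = acc, lr
    print('best lr: {}, val accuracy: {:.4f}'.format(best_lr, best_acc))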
# (Fragment: the opening of this evaluation loop was truncated upstream; what
#  survives begins inside a sess.run feed_dict that produced `pre_labels`.)
                             y: batch_ty,
                             keep_var: 1.
                         })
        # print(pre_labels)
        test_pre_label += pre_labels.tolist()
        # test_true_label += tf.argmax(batch_ty, 1).tolist()
        # pdb.set_trace()

    test_pre_label += rest_test_pre_label
    test_acc = accuracy_score(val_generator.labels, test_pre_label)
    print("{} Iter {}: Testing Accuracy = {:.4f}".format(
        datetime.now(), step, test_acc))
    # all_test_acc.append(test_acc)
    print("F1 score = {:.4f}".format(
        f1_score(val_generator.labels, test_pre_label, average='macro')))
    print("Confusion matrix =")
    print(" {} ".format(confusion_matrix(val_generator.labels, test_pre_label)))

    # Reset the file pointer of the image data generator
    val_generator.reset_pointer()  # 2attention
    train_generator.reset_pointer()

    # Save model
    temp = '%05d' % step
    model_path_1 = model_path + temp + '.ckpt'
    # save_path = saver.save(sess, model_path_1)
    # print("Model saved in file: {}".format(save_path))
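
# The evaluation above leans on scikit-learn. A minimal, self-contained
# sketch of the same three metric calls, with made-up labels:
from sklearn.metrics import accuracy_score, confusion_matrix, f1_score

true_labels = [0, 1, 2, 1, 0, 2]
pred_labels = [0, 1, 1, 1, 0, 2]
print("accuracy: {:.4f}".format(accuracy_score(true_labels, pred_labels)))
# 'macro' averages per-class F1 scores without class-frequency weighting
print("macro F1: {:.4f}".format(f1_score(true_labels, pred_labels, average='macro')))
print(confusion_matrix(true_labels, pred_labels))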
def main(_):
    # Create training directories
    now = datetime.datetime.now()
    train_dir_name = now.strftime('resnet_%Y%m%d_%H%M%S')
    train_dir = os.path.join(FLAGS.tensorboard_root_dir, train_dir_name)
    checkpoint_dir = os.path.join(train_dir, 'checkpoint')
    tensorboard_dir = os.path.join(train_dir, 'tensorboard')
    tensorboard_train_dir = os.path.join(tensorboard_dir, 'train')
    tensorboard_val_dir = os.path.join(tensorboard_dir, 'val')

    if not os.path.isdir(FLAGS.tensorboard_root_dir):
        os.mkdir(FLAGS.tensorboard_root_dir)
    if not os.path.isdir(train_dir):
        os.mkdir(train_dir)
    if not os.path.isdir(checkpoint_dir):
        os.mkdir(checkpoint_dir)
    if not os.path.isdir(tensorboard_dir):
        os.mkdir(tensorboard_dir)
    if not os.path.isdir(tensorboard_train_dir):
        os.mkdir(tensorboard_train_dir)
    if not os.path.isdir(tensorboard_val_dir):
        os.mkdir(tensorboard_val_dir)

    # Write flags to txt
    flags_file_path = os.path.join(train_dir, 'flags.txt')
    flags_file = open(flags_file_path, 'w')
    flags_file.write('learning_rate={}\n'.format(FLAGS.learning_rate))
    flags_file.write('resnet_depth={}\n'.format(FLAGS.resnet_depth))
    flags_file.write('num_epochs={}\n'.format(FLAGS.num_epochs))
    flags_file.write('batch_size={}\n'.format(FLAGS.batch_size))
    flags_file.write('train_layers={}\n'.format(FLAGS.train_layers))
    flags_file.write('multi_scale={}\n'.format(FLAGS.multi_scale))
    flags_file.write('tensorboard_root_dir={}\n'.format(FLAGS.tensorboard_root_dir))
    flags_file.write('log_step={}\n'.format(FLAGS.log_step))
    flags_file.close()

    # Placeholders
    x = tf.placeholder(tf.float32, [FLAGS.batch_size, 224, 224, 3])
    y = tf.placeholder(tf.float32, [None, FLAGS.num_classes])
    is_training = tf.placeholder('bool', [])

    # Model
    train_layers = FLAGS.train_layers.split(',')
    model = ResNetModel(is_training, depth=FLAGS.resnet_depth, num_classes=FLAGS.num_classes)
    loss = model.loss(x, y)
    train_op = model.optimize(FLAGS.learning_rate, train_layers)

    # Training accuracy of the model
    correct_pred = tf.equal(tf.argmax(model.prob, 1), tf.argmax(y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

    # Summaries
    tf.summary.scalar('train_loss', loss)
    tf.summary.scalar('train_accuracy', accuracy)
    merged_summary = tf.summary.merge_all()

    train_writer = tf.summary.FileWriter(tensorboard_train_dir)
    val_writer = tf.summary.FileWriter(tensorboard_val_dir)
    saver = tf.train.Saver()

    # Batch preprocessors
    multi_scale = FLAGS.multi_scale.split(',')
    if len(multi_scale) == 2:
        multi_scale = [int(multi_scale[0]), int(multi_scale[1])]
    else:
        multi_scale = None

    # Initialize the data generators separately for the training and validation set
    train_generator = ImageDataGenerator(FLAGS.training_file, horizontal_flip=True, shuffle=True)
    val_generator = ImageDataGenerator(FLAGS.val_file, shuffle=False)
    # train_preprocessor = BatchPreprocessor(dataset_file_path=FLAGS.training_file, num_classes=FLAGS.num_classes,
    #                                        output_size=[224, 224], horizontal_flip=True, shuffle=True, multi_scale=multi_scale)
    # val_preprocessor = BatchPreprocessor(dataset_file_path=FLAGS.val_file, num_classes=FLAGS.num_classes, output_size=[224, 224])

    # Get the number of training/validation steps per epoch
    train_batches_per_epoch = np.floor(train_generator.data_size / FLAGS.batch_size).astype(np.int16)
    val_batches_per_epoch = np.floor(val_generator.data_size / FLAGS.batch_size).astype(np.int16)
    # train_batches_per_epoch = np.floor(len(train_preprocessor.labels) / FLAGS.batch_size).astype(np.int16)
    # val_batches_per_epoch = np.floor(len(val_preprocessor.labels) / FLAGS.batch_size).astype(np.int16)

    os.environ["CUDA_VISIBLE_DEVICES"] = "1"
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.9)

    with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
        sess.run(tf.global_variables_initializer())
        train_writer.add_graph(sess.graph)

        # Load the pretrained weights
        model.load_original_weights(sess, skip_layers=train_layers)

        # Directly restore (your model should be exactly the same as the checkpoint)
        # saver.restore(sess, "/Users/dgurkaynak/Projects/marvel-training/alexnet64-fc6/model_epoch10.ckpt")

        print("{} Start training...".format(datetime.datetime.now()))
        print("{} Open TensorBoard at --logdir {}".format(datetime.datetime.now(), tensorboard_dir))

        for epoch in range(FLAGS.num_epochs):
            print("{} Epoch number: {}".format(datetime.datetime.now(), epoch + 1))
            step = 1

            # Start training
            while step < train_batches_per_epoch:
                # Get a batch of images and labels
                batch_xs, batch_ys = train_generator.next_batch(FLAGS.batch_size)
                # batch_xs, batch_ys = train_preprocessor.next_batch(FLAGS.batch_size)
                sess.run(train_op, feed_dict={x: batch_xs, y: batch_ys, is_training: True})

                # Logging
                if step % FLAGS.log_step == 0:
                    s = sess.run(merged_summary, feed_dict={x: batch_xs, y: batch_ys, is_training: False})
                    train_writer.add_summary(s, epoch * train_batches_per_epoch + step)
                step += 1

            # Epoch completed, start validation
            print("{} Start validation".format(datetime.datetime.now()))
            test_acc = 0.
            test_count = 0
            for _ in range(val_batches_per_epoch):
                batch_tx, batch_ty = val_generator.next_batch(FLAGS.batch_size)
                # batch_tx, batch_ty = val_preprocessor.next_batch(FLAGS.batch_size)
                acc = sess.run(accuracy, feed_dict={x: batch_tx, y: batch_ty, is_training: False})
                test_acc += acc
                test_count += 1
            test_acc /= test_count

            s = tf.Summary(value=[
                tf.Summary.Value(tag="validation_accuracy", simple_value=test_acc)
            ])
            val_writer.add_summary(s, epoch + 1)
            print("{} Validation Accuracy = {:.4f}".format(datetime.datetime.now(), test_acc))

            # Reset the file pointer of the image data generator
            val_generator.reset_pointer()
            train_generator.reset_pointer()

            print("{} Saving checkpoint of model...".format(datetime.datetime.now()))
            # Save a checkpoint of the model
            checkpoint_path = os.path.join(checkpoint_dir, 'model_epoch' + str(epoch + 1) + '.ckpt')
            save_path = saver.save(sess, checkpoint_path)
            print("{} Model checkpoint saved at {}".format(datetime.datetime.now(), checkpoint_path))
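
# A sketch of the FLAGS declarations the ResNet script above reads, using the
# TF 1.x tf.app.flags mechanism; the defaults are illustrative assumptions.
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_float('learning_rate', 0.0001, 'Learning rate')
tf.app.flags.DEFINE_integer('resnet_depth', 50, 'ResNet depth: 18, 34, 50, 101, 152')
tf.app.flags.DEFINE_integer('num_epochs', 10, 'Number of training epochs')
tf.app.flags.DEFINE_integer('batch_size', 32, 'Batch size')
tf.app.flags.DEFINE_integer('num_classes', 2, 'Number of output classes')
tf.app.flags.DEFINE_string('train_layers', 'fc', 'Comma-separated layers to finetune')
tf.app.flags.DEFINE_string('multi_scale', '', 'e.g. 224,280 for random multi-scale crops')
tf.app.flags.DEFINE_string('training_file', 'train.txt', 'Training dataset list')
tf.app.flags.DEFINE_string('val_file', 'val.txt', 'Validation dataset list')
tf.app.flags.DEFINE_string('tensorboard_root_dir', './training', 'Root output dir')
tf.app.flags.DEFINE_integer('log_step', 10, 'Summary logging frequency (steps)')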
def main(unused_argv):
    if FLAGS.job_name is None or FLAGS.job_name == '':
        raise ValueError('Must specify an explicit job_name')
    else:
        print('job_name: %s' % FLAGS.job_name)
    if FLAGS.task_index is None or FLAGS.task_index == '':
        raise ValueError('Must specify an explicit task_index')
    else:
        print('task_index: %s' % FLAGS.task_index)

    ps_spec = FLAGS.ps_hosts.split(',')
    worker_spec = FLAGS.worker_hosts.split(',')
    num_worker = len(worker_spec)
    cluster = tf.train.ClusterSpec({'ps': ps_spec, 'worker': worker_spec})
    kill_ps_queue = create_done_queue(num_worker)

    server = tf.train.Server(cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index)
    if FLAGS.job_name == 'ps':
        # server.join()
        with tf.Session(server.target) as sess:
            for i in range(num_worker):
                sess.run(kill_ps_queue.dequeue())
        return

    is_chief = (FLAGS.task_index == 0)

    if FLAGS.use_gpu:
        worker_device = '/job:worker/task:%d/gpu:%d' % (FLAGS.task_index, FLAGS.gpu_id)
    else:
        worker_device = '/job:worker/task:%d/cpu:0' % FLAGS.task_index

    with tf.device(
            tf.train.replica_device_setter(worker_device=worker_device,
                                           ps_device='/job:ps/cpu:0',
                                           cluster=cluster)):
        global_step = tf.Variable(0, name='global_step', trainable=False)
        x = tf.placeholder(tf.float32, [None, 227, 227, 3], name='x')
        y = tf.placeholder(tf.float32, [None, FLAGS.n_classes], name='y')
        keep_prob = tf.placeholder(tf.float32, name='kp')

        model = AlexNet(x, keep_prob, FLAGS.n_classes)
        score = model.fc3
        cross_entropy = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=score))
        tf.summary.scalar('loss', cross_entropy)

        opt = get_optimizer('Adam', FLAGS.learning_rate)

        if FLAGS.sync_replicas:
            replicas_to_aggregate = num_worker
            opt = tf.train.SyncReplicasOptimizer(
                opt,
                replicas_to_aggregate=replicas_to_aggregate,
                total_num_replicas=num_worker,
                use_locking=False,
                name='sync_replicas')

        train_op = opt.minimize(cross_entropy, global_step=global_step)

        correct_prediction = tf.equal(tf.argmax(score, 1), tf.argmax(y, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
        tf.summary.scalar('accuracy', accuracy)

        if FLAGS.sync_replicas:
            local_init_op = opt.local_step_init_op
            if is_chief:
                local_init_op = opt.chief_init_op
            ready_for_local_init_op = opt.ready_for_local_init_op
            chief_queue_runner = opt.get_chief_queue_runner()
            init_token_op = opt.get_init_tokens_op()

        init_op = tf.global_variables_initializer()
        kill_ps_enqueue_op = kill_ps_queue.enqueue(1)
        summary_op = tf.summary.merge_all()
        writer = tf.summary.FileWriter(FLAGS.logdir)
        saver = tf.train.Saver()
        # train_dir = tempfile.mkdtemp()

    if FLAGS.sync_replicas:
        sv = tf.train.Supervisor(
            is_chief=is_chief,
            logdir=FLAGS.checkpoint,
            init_op=init_op,
            local_init_op=local_init_op,
            ready_for_local_init_op=ready_for_local_init_op,
            summary_op=summary_op,
            saver=saver,
            summary_writer=writer,
            recovery_wait_secs=1,
            global_step=global_step)
    else:
        sv = tf.train.Supervisor(
            is_chief=is_chief,
            logdir=FLAGS.checkpoint,
            init_op=init_op,
            recovery_wait_secs=1,
            summary_op=summary_op,
            saver=saver,
            summary_writer=writer,
            global_step=global_step)

    sess_config = tf.ConfigProto(
        allow_soft_placement=True,
        log_device_placement=False,
        device_filters=['/job:ps', '/job:worker/task:%d' % FLAGS.task_index])

    if is_chief:
        print('Worker %d: Initializing session...' % FLAGS.task_index)
    else:
        print('Worker %d: Waiting for session to be initialized...' % FLAGS.task_index)

    sess = sv.prepare_or_wait_for_session(server.target, config=sess_config)
    print('Worker %d: Session initialization complete.' % FLAGS.task_index)

    if FLAGS.sync_replicas and is_chief:
        sess.run(init_token_op)
        sv.start_queue_runners(sess, [chief_queue_runner])

    train_generator = ImageDataGenerator(FLAGS.train_file, horizontal_flip=True, shuffle=True)
    val_generator = ImageDataGenerator(FLAGS.val_file, shuffle=False)

    # Get the number of training/validation steps per epoch
    train_batches_per_epoch = np.floor(train_generator.data_size / FLAGS.batch_size).astype(np.int16)
    val_batches_per_epoch = np.floor(val_generator.data_size / FLAGS.batch_size).astype(np.int16)

    print("{} Start training...".format(datetime.now()))
    print("{} Open TensorBoard at --logdir {}".format(datetime.now(), FLAGS.logdir))

    for epoch in range(FLAGS.num_epoches):
        print("{} Epoch number: {}".format(datetime.now(), epoch + 1))
        step = 1

        while step < train_batches_per_epoch:
            start_time = time.time()
            # Get a batch of images and labels
            batch_xs, batch_ys = train_generator.next_batch(FLAGS.batch_size)

            # And run the training op
            _, loss, gstep = sess.run(
                [train_op, cross_entropy, global_step],
                feed_dict={x: batch_xs, y: batch_ys, keep_prob: FLAGS.dropout})
            print('total step: %d, loss: %f' % (gstep, loss))
            duration = time.time() - start_time

            # Generate summary with the current batch of data and write to file
            if step % FLAGS.display_step == 0:
                s = sess.run(sv.summary_op,
                             feed_dict={x: batch_xs, y: batch_ys, keep_prob: 1.})
                writer.add_summary(s, epoch * train_batches_per_epoch + step)

            if step % 10 == 0:
                print("[INFO] {} images trained. time used: {}".format(
                    step * FLAGS.batch_size, duration))
            step += 1

        # Validate the model on the entire validation set
        print("{} Start validation".format(datetime.now()))
        test_acc = 0.
        test_count = 0
        for _ in range(val_batches_per_epoch):
            batch_tx, batch_ty = val_generator.next_batch(FLAGS.batch_size)
            acc = sess.run(accuracy, feed_dict={x: batch_tx, y: batch_ty, keep_prob: 1.})
            test_acc += acc
            test_count += 1
        test_acc /= test_count
        print("{} Validation Accuracy = {:.4f}".format(datetime.now(), test_acc))

        # Reset the file pointer of the image data generator
        val_generator.reset_pointer()
        train_generator.reset_pointer()

        print("{} Saving checkpoint of model...".format(datetime.now()))
        # Save a checkpoint of the model
        checkpoint_name = os.path.join(FLAGS.checkpoint, 'model_epoch' + str(epoch + 1) + '.ckpt')
        save_path = sv.saver.save(sess, checkpoint_name)
        print("{} Model checkpoint saved at {}".format(datetime.now(), checkpoint_name))
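
# A minimal sketch of `create_done_queue`, which the distributed script above
# calls but never defines here. It is assumed to follow the common TF 1.x
# pattern: a FIFO queue shared via `shared_name` and pinned to the parameter
# server, which each worker enqueues into when finished so the ps process can
# count the tokens and exit.
def create_done_queue(num_workers):
    with tf.device('/job:ps/task:0'):
        return tf.FIFOQueue(num_workers, tf.int32, shared_name='done_queue')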
def fine_tuning(self, train_list, test_list, mean, snapshot, filewriter_path):
    # Learning params
    learning_rate = 0.0005
    num_epochs = 151
    batch_size = 64

    # Network params
    in_img_size = (227, 227)  # (height, width)
    dropout_rate = 1
    num_classes = 2
    train_layers = ['fc7', 'fc8']

    # How often we want to write the tf.summary data to disk
    display_step = 30

    x = tf.placeholder(tf.float32, [batch_size, in_img_size[0], in_img_size[1], 3])
    y = tf.placeholder(tf.float32, [None, num_classes])
    keep_prob = tf.placeholder(tf.float32)

    # Initialize model
    model = alexnet(x, keep_prob, num_classes, train_layers, in_size=in_img_size)

    # Link variable to model output
    score = model.fc8

    # List of trainable variables of the layers we want to train
    var_list = [v for v in tf.trainable_variables()
                if v.name.split('/')[0] in train_layers]

    # Op for calculating the loss
    with tf.name_scope("cross_ent"):
        loss = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(logits=score, labels=y))

    # Train op
    # Get gradients of all trainable variables
    gradients = tf.gradients(loss, var_list)
    gradients = list(zip(gradients, var_list))
    '''
    # Create optimizer and apply gradient descent to the trainable variables
    learning_rate = tf.train.exponential_decay(learning_rate,
                                               global_step=tf.Variable(0, trainable=False),
                                               decay_steps=10, decay_rate=0.9)
    '''
    optimizer = tf.train.MomentumOptimizer(learning_rate, 0.9)
    train_op = optimizer.minimize(loss)

    # Add gradients to summary
    for gradient, var in gradients:
        tf.summary.histogram(var.name + '/gradient', gradient)

    # Add the variables we train to the summary
    for var in var_list:
        tf.summary.histogram(var.name, var)

    # Add the loss to summary
    tf.summary.scalar('cross_entropy', loss)

    # Evaluation op: accuracy of the model
    with tf.name_scope("accuracy"):
        correct_pred = tf.equal(tf.argmax(score, 1), tf.argmax(y, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

    # Add the accuracy to the summary
    tf.summary.scalar('accuracy', accuracy)

    # Merge all summaries together
    merged_summary = tf.summary.merge_all()

    # Initialize the FileWriter
    writer = tf.summary.FileWriter(filewriter_path)

    # Initialize a saver for storing model checkpoints
    saver = tf.train.Saver()

    # Initialize the data generators separately for the training and validation set
    train_generator = ImageDataGenerator(train_list,
                                         horizontal_flip=True,
                                         shuffle=False,
                                         mean=mean,
                                         scale_size=in_img_size,
                                         nb_classes=num_classes)
    val_generator = ImageDataGenerator(test_list,
                                       shuffle=False,
                                       mean=mean,
                                       scale_size=in_img_size,
                                       nb_classes=num_classes)

    # Get the number of training/validation steps per epoch
    train_batches_per_epoch = np.floor(train_generator.data_size / batch_size).astype(np.int16)
    val_batches_per_epoch = np.floor(val_generator.data_size / batch_size).astype(np.int16)

    # Start TensorFlow session
    with tf.Session() as sess:
        # Initialize all variables
        sess.run(tf.global_variables_initializer())

        # Add the model graph to TensorBoard
        writer.add_graph(sess.graph)

        # Load the pretrained weights into the non-trainable layers
        model.load_initial_weights(sess)

        print("{} Start training...".format(datetime.now()))
        print("{} Open TensorBoard at --logdir {}".format(datetime.now(), filewriter_path))

        # Loop over number of epochs
        for epoch in range(num_epochs):
            print("{} Epoch number: {}/{}".format(datetime.now(), epoch + 1, num_epochs))
            step = 1

            while step < train_batches_per_epoch:
                # Get a batch of images and labels
                batch_xs, batch_ys = train_generator.next_batch(batch_size)

                # And run the training op
                sess.run(train_op, feed_dict={
                    x: batch_xs,
                    y: batch_ys,
                    keep_prob: dropout_rate
                })

                # Generate summary with the current batch of data and write to file
                if step % display_step == 0:
                    s = sess.run(merged_summary, feed_dict={
                        x: batch_xs,
                        y: batch_ys,
                        keep_prob: 1.
                    })
                    writer.add_summary(s, epoch * train_batches_per_epoch + step)
                step += 1

            # Validate the model on the entire validation set
            print("{} Start validation".format(datetime.now()))
            test_acc = 0.
            test_count = 0
            for _ in range(val_batches_per_epoch):
                batch_tx, batch_ty = val_generator.next_batch(batch_size)
                acc = sess.run(accuracy, feed_dict={
                    x: batch_tx,
                    y: batch_ty,
                    keep_prob: 1.
                })
                test_acc += acc
                test_count += 1
            test_acc /= test_count
            print("{} Validation Accuracy = {:.4f}".format(datetime.now(), test_acc))

            # Reset the file pointer of the image data generator
            val_generator.reset_pointer()
            train_generator.reset_pointer()

            # Save a checkpoint of the model every `display_step` epochs
            if epoch % display_step == 0:
                print("{} Saving checkpoint of model...".format(datetime.now()))
                checkpoint_name = os.path.join(snapshot, 'model_epoch' + str(epoch) + '.ckpt')
                save_path = saver.save(sess, checkpoint_name)
                print("{} Model checkpoint saved at {}".format(datetime.now(), checkpoint_name))
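
# A usage sketch for fine_tuning() above. Since it is a method, it must be
# called on whatever class hosts it; `FineTuner` is a hypothetical stand-in,
# and the file lists, mean vector, and paths are illustrative assumptions.
if __name__ == '__main__':
    trainer = FineTuner()  # hypothetical host class for fine_tuning()
    imagenet_mean = np.array([104., 117., 124.], dtype=np.float32)  # BGR channel means
    trainer.fine_tuning(train_list='train.txt',
                        test_list='val.txt',
                        mean=imagenet_mean,
                        snapshot='./snapshots',
                        filewriter_path='./tensorboard')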