import os
import sys

import numpy as np
import tensorflow as tf

# read_image_batch, input_placeholder, build_net_multi_label, cost, train_op,
# model_mAP, get_next_batch_from_path, shuffle_train_data and
# data_load_from_txt_mullabel are project helpers imported from the repo's
# own modules.


def train_multi_label_tfRecords(train_file_tfRecord, valid_file_tfRecord,
                                train_dir, num_classes, batch_size, arch_model,
                                learning_r_decay, learning_rate_base,
                                decay_rate, dropout_prob, epoch, height, width,
                                checkpoint_exclude_scopes, early_stop,
                                EARLY_STOP_PATIENCE, fine_tune,
                                train_all_layers, checkpoint_path, train_n,
                                valid_n, g_parameter):
    # ----------------------------- input pipeline -------------------------------#
    print(train_file_tfRecord)
    train_image_filename_queue = tf.train.string_input_producer(
        tf.train.match_filenames_once(train_file_tfRecord))
    train_images, train_labels = read_image_batch(
        train_image_filename_queue, num_classes, batch_size, height, width)
    valid_image_filename_queue = tf.train.string_input_producer(
        tf.train.match_filenames_once(valid_file_tfRecord))
    valid_images, valid_labels = read_image_batch(
        valid_image_filename_queue, num_classes, batch_size, height, width)

    # Build the resize/normalization ops once, outside the training loop.
    # Creating them per iteration (as the original code did) keeps adding
    # nodes to the graph every step, which slows training and leaks memory.
    train_image_op = tf.image.resize_images(train_images, (height, width),
                                            method=0)
    train_image_op = (tf.cast(train_image_op, tf.float32) / 255 - 0.5) * 2
    valid_image_op = tf.image.resize_images(valid_images, (height, width),
                                            method=0)
    valid_image_op = (tf.cast(valid_image_op, tf.float32) / 255 - 0.5) * 2

    # --------------------------------- model ------------------------------------#
    X, Y, is_train, keep_prob_fc = input_placeholder(height, width, num_classes)
    net, _ = build_net_multi_label(X, num_classes, keep_prob_fc, is_train,
                                   arch_model)
    variables_to_restore, variables_to_train = g_parameter(
        checkpoint_exclude_scopes)
    loss = cost(Y, net)
    global_step = tf.Variable(0, trainable=False)
    if learning_r_decay:
        learning_rate = tf.train.exponential_decay(learning_rate_base,
                                                   global_step * batch_size,
                                                   train_n,
                                                   decay_rate,
                                                   staircase=True)
    else:
        learning_rate = learning_rate_base
    if train_all_layers:
        variables_to_train = []
    optimizer = train_op(learning_rate, loss, variables_to_train, global_step)
    pre = tf.nn.sigmoid(net)
    accuracy = model_mAP(pre, Y)

    # -------------------------------- session -----------------------------------#
    sess = tf.InteractiveSession()
    tf.local_variables_initializer().run()
    tf.global_variables_initializer().run()
    saver2 = tf.train.Saver(tf.global_variables())
    if not train_all_layers:
        saver_net = tf.train.Saver(variables_to_restore)
        saver_net.restore(sess, checkpoint_path)

    if fine_tune:
        latest = tf.train.latest_checkpoint(train_dir)
        if not latest:
            print("No checkpoint to continue from in", train_dir)
            sys.exit(1)
        print("resume", latest)
        saver2.restore(sess, latest)

    # early stopping
    best_valid = np.inf
    best_valid_epoch = 0

    # start queue runners
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    for epoch_i in range(epoch):
        for batch_i in range(int(train_n / batch_size)):
            train_image, train_label = sess.run([train_image_op, train_labels])
            los, _ = sess.run([loss, optimizer],
                              feed_dict={
                                  X: train_image,
                                  Y: train_label,
                                  is_train: True,
                                  keep_prob_fc: dropout_prob
                              })
            if batch_i % 100 == 0:
                loss_, acc_ = sess.run([loss, accuracy],
                                       feed_dict={
                                           X: train_image,
                                           Y: train_label,
                                           is_train: False,
                                           keep_prob_fc: 1.0
                                       })
                print('Batch: {:>2}: Training loss: {:>3.5f}, Training mAP: {:>3.5f}'
                      .format(batch_i, loss_, acc_))
            if batch_i % 500 == 0:
                valid_image, valid_label = sess.run(
                    [valid_image_op, valid_labels])
                ls, acc = sess.run([loss, accuracy],
                                   feed_dict={
                                       X: valid_image,
                                       Y: valid_label,
                                       is_train: False,
                                       keep_prob_fc: 1.0
                                   })
                print('Batch: {:>2}: Validation loss: {:>3.5f}, Validation mAP: {:>3.5f}'
                      .format(batch_i, ls, acc))
                checkpoint_path = os.path.join(train_dir, 'model.ckpt')
                saver2.save(sess, checkpoint_path, global_step=batch_i,
                            write_meta_graph=False)

        print('Epoch===================================>: {:>2}'.format(epoch_i))
        valid_ls = 0
        valid_acc = 0
        n_valid_batches = int(valid_n / batch_size)
        for batch_i in range(n_valid_batches):
            valid_image, valid_label = sess.run([valid_image_op, valid_labels])
            epoch_ls, epoch_acc = sess.run([loss, accuracy],
                                           feed_dict={
                                               X: valid_image,
                                               Y: valid_label,
                                               keep_prob_fc: 1.0,
                                               is_train: False
                                           })
            valid_ls += epoch_ls
            valid_acc += epoch_acc
        print('Epoch: {:>2}: Validation loss: {:>3.5f}, Validation mAP: {:>3.5f}'
              .format(epoch_i, valid_ls / n_valid_batches,
                      valid_acc / n_valid_batches))

        if valid_acc / n_valid_batches > 0.90:
            checkpoint_path = os.path.join(train_dir, 'model.ckpt')
            saver2.save(sess, checkpoint_path, global_step=epoch_i,
                        write_meta_graph=False)

        # ------------------------------ early stopping --------------------------#
        if early_stop:
            loss_valid = valid_ls / n_valid_batches
            if loss_valid < best_valid:
                best_valid = loss_valid
                best_valid_epoch = epoch_i
            elif best_valid_epoch + EARLY_STOP_PATIENCE < epoch_i:
                print("Early stopping.")
                print("Best valid loss was {:.6f} at epoch {}.".format(
                    best_valid, best_valid_epoch))
                break

    # stop queue runners
    coord.request_stop()
    coord.join(threads)
    sess.close()
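
# ---------------------------------------------------------------------------------#
# For reference, a minimal sketch of what a read_image_batch-style TFRecord
# reader could look like. This is illustrative only, not the repo's actual
# implementation: it assumes each record stores raw RGB bytes under
# 'image_raw' and a float multi-hot label vector under 'label'.
def read_image_batch_sketch(filename_queue, num_classes, batch_size, height,
                            width):
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized_example,
        features={
            'image_raw': tf.FixedLenFeature([], tf.string),
            'label': tf.FixedLenFeature([num_classes], tf.float32),
        })
    image = tf.decode_raw(features['image_raw'], tf.uint8)
    image = tf.reshape(image, [height, width, 3])
    label = features['label']
    # shuffle_batch starts background threads that keep a buffer of decoded
    # examples filled and draws randomized batches from it
    images, labels = tf.train.shuffle_batch([image, label],
                                            batch_size=batch_size,
                                            capacity=10 * batch_size,
                                            min_after_dequeue=2 * batch_size)
    return images, labels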
# ---------------------------------------------------------------------------------#
# Standalone evaluation script: restores the latest checkpoint from train_dir
# and reports loss/mAP over the samples listed in val.txt. height, width,
# num_classes and arch_model come from the project's config, as batch_size
# does below.
batch_size = config.batch_size
sample_dir = 'val.txt'
train_rate = 1.0
test_data, test_label, valid_data, valid_label, test_n, valid_n, note_label = \
    data_load_from_txt_mullabel(sample_dir, train_rate).gen_train_valid()
print('test_n', test_n)
print('valid_n', valid_n)

X, Y, is_train, keep_prob_fc = input_placeholder(height, width, num_classes)
net, _ = build_net_multi_label(X, num_classes, keep_prob_fc, is_train,
                               arch_model)
loss = cost(Y, net)
pre = tf.nn.sigmoid(net)
accuracy = model_mAP(pre, Y)

if __name__ == '__main__':
    train_dir = 'model'
    latest = tf.train.latest_checkpoint(train_dir)
    if not latest:
        print("No checkpoint to continue from in", train_dir)
        sys.exit(1)
    print("resume", latest)
    sess = tf.Session()
    saver = tf.train.Saver(tf.global_variables())
    saver.restore(sess, latest)
    test_ls = 0
    test_acc = 0
    n_test_batches = int(test_n / batch_size)
    # Accumulate loss/mAP over the test set, following the same pattern as
    # the per-epoch validation loop in train_multi_label below.
    for batch_i in range(n_test_batches):
        images, labels = get_next_batch_from_path(test_data, test_label,
                                                  batch_i, height, width,
                                                  batch_size=batch_size,
                                                  training=False)
        ls, acc = sess.run([loss, accuracy],
                           feed_dict={
                               X: images,
                               Y: labels,
                               is_train: False,
                               keep_prob_fc: 1.0
                           })
        test_ls += ls
        test_acc += acc
    print('Test loss: {:>3.5f}, Test mAP: {:>3.5f}'.format(
        test_ls / n_test_batches, test_acc / n_test_batches))
    sess.close()
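
# ---------------------------------------------------------------------------------#
# model_mAP scores the ranking of the sigmoid probabilities; to turn them into
# hard multi-label predictions you would threshold them instead. A minimal
# sketch, assuming a 0.5 cut-off (the threshold is a tunable choice, not
# something fixed by this repo):
def predict_multi_label(sess, images, threshold=0.5):
    probs = sess.run(pre, feed_dict={X: images,
                                     is_train: False,
                                     keep_prob_fc: 1.0})
    # one 0/1 flag per class; a sample may belong to several classes at once
    return (probs > threshold).astype(np.int32)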
def train_multi_label(train_data, train_label, valid_data, valid_label,
                      train_dir, num_classes, batch_size, arch_model,
                      learning_r_decay, learning_rate_base, decay_rate,
                      dropout_prob, epoch, height, width,
                      checkpoint_exclude_scopes, early_stop,
                      EARLY_STOP_PATIENCE, fine_tune, train_all_layers,
                      checkpoint_path, train_n, valid_n, g_parameter):
    # --------------------------------- model ------------------------------------#
    X, Y, is_train, keep_prob_fc = input_placeholder(height, width, num_classes)
    net, _ = build_net_multi_label(X, num_classes, keep_prob_fc, is_train,
                                   arch_model)
    variables_to_restore, variables_to_train = g_parameter(
        checkpoint_exclude_scopes)
    loss = cost(Y, net)
    global_step = tf.Variable(0, trainable=False)
    if learning_r_decay:
        learning_rate = tf.train.exponential_decay(learning_rate_base,
                                                   global_step * batch_size,
                                                   train_n,
                                                   decay_rate,
                                                   staircase=True)
    else:
        learning_rate = learning_rate_base
    if train_all_layers:
        variables_to_train = []
    optimizer = train_op(learning_rate, loss, variables_to_train, global_step)
    # Compute mAP on sigmoid probabilities, matching the tfRecords trainer and
    # the evaluation script above (the original passed raw logits here).
    pre = tf.nn.sigmoid(net)
    accuracy = model_mAP(pre, Y)

    # -------------------------------- session -----------------------------------#
    sess = tf.Session()
    init = tf.global_variables_initializer()
    sess.run(init)
    saver2 = tf.train.Saver(tf.global_variables())
    if not train_all_layers:
        saver_net = tf.train.Saver(variables_to_restore)
        saver_net.restore(sess, checkpoint_path)

    if fine_tune:
        latest = tf.train.latest_checkpoint(train_dir)
        if not latest:
            print("No checkpoint to continue from in", train_dir)
            sys.exit(1)
        print("resume", latest)
        saver2.restore(sess, latest)

    # early stopping
    best_valid = np.inf
    best_valid_epoch = 0

    for epoch_i in range(epoch):
        for batch_i in range(int(train_n / batch_size)):
            images, labels = get_next_batch_from_path(train_data, train_label,
                                                      batch_i, height, width,
                                                      batch_size=batch_size,
                                                      training=True)
            los, _ = sess.run([loss, optimizer],
                              feed_dict={
                                  X: images,
                                  Y: labels,
                                  is_train: True,
                                  keep_prob_fc: dropout_prob
                              })
            print(los)
            # checkpoint after every batch; write_meta_graph=False keeps
            # these frequent saves small
            checkpoint_path = os.path.join(train_dir, 'model.ckpt')
            saver2.save(sess, checkpoint_path, global_step=batch_i,
                        write_meta_graph=False)
            if batch_i % 20 == 0:
                loss_, acc_ = sess.run([loss, accuracy],
                                       feed_dict={
                                           X: images,
                                           Y: labels,
                                           is_train: False,
                                           keep_prob_fc: 1.0
                                       })
                print('Batch: {:>2}: Training loss: {:>3.5f}, Training mAP: {:>3.5f}'
                      .format(batch_i, loss_, acc_))
            if batch_i % 100 == 0:
                images, labels = get_next_batch_from_path(
                    valid_data, valid_label,
                    batch_i % int(valid_n / batch_size), height, width,
                    batch_size=batch_size, training=False)
                ls, acc = sess.run([loss, accuracy],
                                   feed_dict={
                                       X: images,
                                       Y: labels,
                                       is_train: False,
                                       keep_prob_fc: 1.0
                                   })
                print('Batch: {:>2}: Validation loss: {:>3.5f}, Validation mAP: {:>3.5f}'
                      .format(batch_i, ls, acc))

        print('Epoch===================================>: {:>2}'.format(epoch_i))
        valid_ls = 0
        valid_acc = 0
        n_valid_batches = int(valid_n / batch_size)
        for batch_i in range(n_valid_batches):
            images_valid, labels_valid = get_next_batch_from_path(
                valid_data, valid_label, batch_i, height, width,
                batch_size=batch_size, training=False)
            epoch_ls, epoch_acc = sess.run([loss, accuracy],
                                           feed_dict={
                                               X: images_valid,
                                               Y: labels_valid,
                                               keep_prob_fc: 1.0,
                                               is_train: False
                                           })
            valid_ls += epoch_ls
            valid_acc += epoch_acc
        print('Epoch: {:>2}: Validation loss: {:>3.5f}, Validation mAP: {:>3.5f}'
              .format(epoch_i, valid_ls / n_valid_batches,
                      valid_acc / n_valid_batches))

        if valid_acc / n_valid_batches > 0.90:
            checkpoint_path = os.path.join(train_dir, 'model.ckpt')
            saver2.save(sess, checkpoint_path, global_step=epoch_i,
                        write_meta_graph=False)

        # ------------------------------ early stopping --------------------------#
        if early_stop:
            loss_valid = valid_ls / n_valid_batches
            if loss_valid < best_valid:
                best_valid = loss_valid
                best_valid_epoch = epoch_i
            elif best_valid_epoch + EARLY_STOP_PATIENCE < epoch_i:
                print("Early stopping.")
                print("Best valid loss was {:.6f} at epoch {}.".format(
                    best_valid, best_valid_epoch))
                break

        # reshuffle the training set between epochs
        train_data, train_label = shuffle_train_data(train_data, train_label)

    sess.close()
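
# ---------------------------------------------------------------------------------#
# A minimal sketch of how train_multi_label might be driven. All of the values
# below (paths, hyperparameters, arch name, and the split ratio) are
# illustrative placeholders, not the repo's actual configuration; g_parameter
# is the variable-partitioning helper used throughout this file.
#
# train_data, train_label, valid_data, valid_label, train_n, valid_n, note_label = \
#     data_load_from_txt_mullabel('train.txt', 0.9).gen_train_valid()
# train_multi_label(train_data, train_label, valid_data, valid_label,
#                   train_dir='model', num_classes=20, batch_size=16,
#                   arch_model='arch_inception_v4', learning_r_decay=True,
#                   learning_rate_base=0.001, decay_rate=0.95,
#                   dropout_prob=0.8, epoch=50, height=299, width=299,
#                   checkpoint_exclude_scopes='Logits_out', early_stop=True,
#                   EARLY_STOP_PATIENCE=5, fine_tune=False,
#                   train_all_layers=True,
#                   checkpoint_path='pretrain/model.ckpt',
#                   train_n=train_n, valid_n=valid_n, g_parameter=g_parameter)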