    :param max_epoch: maximum number of training epochs
    :param display: interval (in steps) between log printouts
    :param lr: learning rate
    :return:
    """
    net_factory = P_Net
    train(net_factory, model_save_path, max_epoch, tfrecord_path, display=display, base_lr=lr)


if __name__ == '__main__':
    path_config = PathConfiguration().config
    # tfrecord training-data paths: [0] label txt used to count samples,
    # [1] shuffled PNet tfrecord directory (see `train`'s use of tfrecord_path)
    tfrecord_folder = [
        path_config.pnet_merge_txt_path,
        path_config.pnet_tfrecord_path_shuffle
    ]
    # Path prefix under which PNet model parameters are saved
    pnet_model_folder = path_config.pnet_landmark_model_path
    if not os.path.exists(os.path.dirname(pnet_model_folder)):
        os.makedirs(os.path.dirname(pnet_model_folder))
    train_PNet(tfrecord_folder, pnet_model_folder, max_epoch=50, display=100, lr=0.001)
def train(net_factory, model_save_path, max_epoch, tfrecord_path, display=200, base_lr=0.01):
    """Train one MTCNN stage (PNet / RNet / ONet).

    :param net_factory: network-building function for the stage; must return
        (cls_loss, bbox_loss, landmark_loss, L2_loss, accuracy) ops when called
        with (image, label, bbox_target, landmark_target, training=True)
    :param model_save_path: checkpoint path prefix; its last path component
        ('PNet' / 'RNet' / ...) selects the stage-specific settings below
    :param max_epoch: maximum number of training epochs
    :param tfrecord_path: list of data paths — [0] is the label txt file used
        only to count training examples; [1] is the single tfrecord dir for
        PNet, or [1..4] are the pos/part/neg/landmark tfrecord dirs otherwise
    :param display: print training statistics every `display` steps
    :param base_lr: initial learning rate handed to train_model
    :return: None
    """
    # The stage name is encoded as the last component of the save path.
    net = model_save_path.split('/')[-1]
    # label file
    label_file = tfrecord_path[0]
    # label_file = os.path.join(base_dir,'landmark_12_few.txt')
    print(label_file)
    # Count training examples. `with` closes the handle immediately —
    # previously the file stayed open for the entire training run.
    with open(label_file, 'r') as f:
        num = len(f.readlines())
    print("Total size of the dataset is: ", num)
    print(model_save_path)

    # PNet reads a single merged tfrecord.
    if net == 'PNet':
        print('dataset dir is:', tfrecord_path[1])
        image_batch, label_batch, bbox_batch, landmark_batch = \
            read_single_tfrecord(tfrecord_path[1], config.BATCH_SIZE, net)
    # RNet and ONet draw each batch from 4 tfrecords (pos/part/neg/landmark).
    else:
        pos_dir = tfrecord_path[1]
        part_dir = tfrecord_path[2]
        neg_dir = tfrecord_path[3]
        landmark_dir = tfrecord_path[4]
        dataset_dirs = [pos_dir, part_dir, neg_dir, landmark_dir]
        # Per-batch sampling ratios: 1/6 pos, 1/6 part, 1/6 landmark, 3/6 neg.
        pos_radio = 1.0 / 6
        part_radio = 1.0 / 6
        landmark_radio = 1.0 / 6
        neg_radio = 3.0 / 6
        pos_batch_size = int(np.ceil(config.BATCH_SIZE * pos_radio))
        assert pos_batch_size != 0, "Batch Size Error "
        part_batch_size = int(np.ceil(config.BATCH_SIZE * part_radio))
        assert part_batch_size != 0, "Batch Size Error "
        neg_batch_size = int(np.ceil(config.BATCH_SIZE * neg_radio))
        assert neg_batch_size != 0, "Batch Size Error "
        landmark_batch_size = int(np.ceil(config.BATCH_SIZE * landmark_radio))
        assert landmark_batch_size != 0, "Batch Size Error "
        batch_sizes = [
            pos_batch_size, part_batch_size, neg_batch_size, landmark_batch_size
        ]
        # print('batch_size is:', batch_sizes)
        image_batch, label_batch, bbox_batch, landmark_batch = read_multi_tfrecords(
            dataset_dirs, batch_sizes, net)

    path_config = PathConfiguration().config
    # Stage-specific input size, per-task loss weights, and log directory.
    if net == 'PNet':
        image_size = 12
        radio_cls_loss = 1.0
        radio_bbox_loss = 1.0
        radio_landmark_loss = 0.5
        logs_dir = path_config.pnet_log_path
    elif net == 'RNet':
        image_size = 24
        radio_cls_loss = 1.0
        radio_bbox_loss = 1.0
        radio_landmark_loss = 0.5
        logs_dir = path_config.rnet_log_path
    else:
        radio_cls_loss = 1.0
        radio_bbox_loss = 1.0
        radio_landmark_loss = 1.0
        image_size = 48
        logs_dir = path_config.onet_log_path

    # define placeholder
    input_image = tf.placeholder(
        tf.float32,
        shape=[config.BATCH_SIZE, image_size, image_size, 3],
        name='input_image')
    label = tf.placeholder(tf.float32, shape=[config.BATCH_SIZE], name='label')
    bbox_target = tf.placeholder(tf.float32, shape=[config.BATCH_SIZE, 4],
                                 name='bbox_target')
    landmark_target = tf.placeholder(tf.float32, shape=[config.BATCH_SIZE, 10],
                                     name='landmark_target')

    # get loss and accuracy
    input_image = image_color_distort(input_image)
    cls_loss_op, bbox_loss_op, landmark_loss_op, L2_loss_op, accuracy_op = net_factory(
        input_image, label, bbox_target, landmark_target, training=True)

    # Weighted sum of the three task losses plus L2 regularization.
    total_loss_op = radio_cls_loss * cls_loss_op + radio_bbox_loss * bbox_loss_op
    total_loss_op = total_loss_op + radio_landmark_loss * landmark_loss_op + L2_loss_op
    train_op, lr_op = train_model(base_lr, total_loss_op, num)

    # init
    os.environ["CUDA_VISIBLE_DEVICES"] = config.VISIBLE_GPU  # '0,1,2,3'
    # NOTE(review): tf.device() outside a `with` block is a no-op — it does
    # not constrain device placement here; confirm whether this was intended.
    tf.device('/gpu:{}'.format(config.GPU))
    init = tf.global_variables_initializer()
    sess = tf.Session()

    # save model
    saver = tf.train.Saver(max_to_keep=0)
    sess.run(init)

    # visualize some variables
    tf.summary.scalar("cls_loss", cls_loss_op)  # cls_loss
    tf.summary.scalar("bbox_loss", bbox_loss_op)  # bbox_loss
    tf.summary.scalar("landmark_loss", landmark_loss_op)  # landmark_loss
    tf.summary.scalar("cls_accuracy", accuracy_op)  # cls_acc
    # cls loss, bbox loss, landmark loss and L2 loss added together
    tf.summary.scalar("total_loss", total_loss_op)
    summary_op = tf.summary.merge_all()
    if not os.path.exists(logs_dir):
        os.makedirs(logs_dir)
    writer = tf.summary.FileWriter(logs_dir, sess.graph)
    projector_config = projector.ProjectorConfig()
    projector.visualize_embeddings(writer, projector_config)

    # begin
    coord = tf.train.Coordinator()
    # begin enqueue thread
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    i = 0  # samples-seen counter, reset after each checkpoint
    # total steps
    MAX_STEP = int(num / config.BATCH_SIZE + 1) * max_epoch
    epoch = 0
    sess.graph.finalize()
    try:
        for step in range(MAX_STEP):
            i = i + 1
            if coord.should_stop():
                break
            image_batch_array, label_batch_array, bbox_batch_array, landmark_batch_array = sess.run(
                [image_batch, label_batch, bbox_batch, landmark_batch])
            # random flip
            image_batch_array, landmark_batch_array = random_flip_images(
                image_batch_array, label_batch_array, landmark_batch_array)
            _, _, summary = sess.run(
                [train_op, lr_op, summary_op],
                feed_dict={
                    input_image: image_batch_array,
                    label: label_batch_array,
                    bbox_target: bbox_batch_array,
                    landmark_target: landmark_batch_array
                })
            if (step + 1) % display == 0:
                # Re-evaluate the individual losses on the same batch for logging.
                cls_loss, bbox_loss, landmark_loss, L2_loss, lr, acc = sess.run(
                    [
                        cls_loss_op, bbox_loss_op, landmark_loss_op, L2_loss_op,
                        lr_op, accuracy_op
                    ],
                    feed_dict={
                        input_image: image_batch_array,
                        label: label_batch_array,
                        bbox_target: bbox_batch_array,
                        landmark_target: landmark_batch_array
                    })
                total_loss = radio_cls_loss * cls_loss + radio_bbox_loss * bbox_loss
                total_loss = total_loss + radio_landmark_loss * landmark_loss + L2_loss
                print((
                    "%s - Step: %d/%d, accuracy: %3f, cls loss: %4f, bbox loss: %4f, " +
                    "Landmark loss :%4f,L2 loss: %4f, Total Loss: %4f ,lr:%f ") %
                      (datetime.now(), step + 1, MAX_STEP, acc, cls_loss,
                       bbox_loss, landmark_loss, L2_loss, total_loss, lr))
            # Checkpoint after roughly every two epochs' worth of samples.
            if i * config.BATCH_SIZE > num * 2:
                epoch = epoch + 1
                i = 0
                path_prefix = saver.save(sess, model_save_path,
                                         global_step=epoch * 2)
                print('path prefix is :', path_prefix)
            writer.add_summary(summary, global_step=step)
    except tf.errors.OutOfRangeError:
        print("异常结束!!!")
    finally:
        # Cleanup now runs on every exit path, including unexpected exceptions.
        print("完成!!!")
        coord.request_stop()
        writer.close()
        coord.join(threads)
        sess.close()