def train(net_factory, prefix, end_epoch, base_dir, log_dir, display=200, base_lr=0.01, quantize=True, ckpt=None, optimizer='momentum'):
    """Train PNet/RNet/ONet, keeping only the best-loss checkpoints.

    This variant supports restoring from a checkpoint, quantization-aware
    training, a selectable optimizer, and saves the model whenever the
    displayed total loss improves.

    :param net_factory: callable building the net; returns the loss/accuracy ops
        and the landmark prediction tensor
    :param prefix: model save path; its basename selects the net
        ('PNet' / 'RNet' / 'ONet')
    :param end_epoch: number of epochs to train for
    :param base_dir: directory holding the label txt and tfrecord files
    :param log_dir: root directory for TensorBoard summaries (a per-net
        subdirectory is created under it)
    :param display: evaluate/log every `display` steps
    :param base_lr: initial learning rate passed to train_model
    :param quantize: forwarded to train_model (quantization-aware training)
    :param ckpt: optional checkpoint path to restore; the global step is
        resumed from the '-N' suffix of its basename
    :param optimizer: optimizer name forwarded to train_model
    :return: None
    """
    print('start training: ....')
    net = prefix.split('/')[-1]
    # Label file: one line per training example; used only to count the dataset.
    label_file = os.path.join(base_dir, 'train_%s_landmark.txt' % net)
    #label_file = os.path.join(base_dir,'landmark_12_few.txt')
    print(label_file)
    # FIX: context manager so the label file is always closed (was leaked).
    with open(label_file, 'r') as f:
        # get number of training examples
        num = len(f.readlines())
    print("Total size of the dataset is: ", num)
    print(prefix)

    # PNet reads one shuffled tfrecord; RNet/ONet mix four record files.
    if net == 'PNet':
        #dataset_dir = os.path.join(base_dir,'train_%s_ALL.tfrecord_shuffle' % net)
        dataset_dir = os.path.join(base_dir, 'train_%s_landmark.tfrecord_shuffle' % net)
        print('dataset dir is:', dataset_dir, 'batch_size = ', config.BATCH_SIZE)
        image_batch, label_batch, bbox_batch, landmark_batch = read_single_tfrecord(
            dataset_dir, config.BATCH_SIZE, net, no_landmarks)
    else:
        pos_dir = os.path.join(base_dir, 'pos_landmark.tfrecord_shuffle')
        part_dir = os.path.join(base_dir, 'part_landmark.tfrecord_shuffle')
        neg_dir = os.path.join(base_dir, 'neg_landmark.tfrecord_shuffle')
        #landmark_dir = os.path.join(base_dir,'landmark_landmark.tfrecord_shuffle')
        landmark_dir = os.path.join(base_dir, 'landmark_landmark.tfrecord_shuffle')
        dataset_dirs = [pos_dir, part_dir, neg_dir, landmark_dir]
        # Per-batch sampling ratios: 1/6 pos, 1/6 part, 1/6 landmark, 3/6 neg.
        pos_radio = 1.0 / 6
        part_radio = 1.0 / 6
        landmark_radio = 1.0 / 6
        neg_radio = 3.0 / 6
        pos_batch_size = int(np.ceil(config.BATCH_SIZE * pos_radio))
        assert pos_batch_size != 0, "Batch Size Error "
        part_batch_size = int(np.ceil(config.BATCH_SIZE * part_radio))
        assert part_batch_size != 0, "Batch Size Error "
        neg_batch_size = int(np.ceil(config.BATCH_SIZE * neg_radio))
        assert neg_batch_size != 0, "Batch Size Error "
        landmark_batch_size = int(np.ceil(config.BATCH_SIZE * landmark_radio))
        assert landmark_batch_size != 0, "Batch Size Error "
        batch_sizes = [
            pos_batch_size, part_batch_size, neg_batch_size,
            landmark_batch_size
        ]
        #print('batch_size is:', batch_sizes)
        image_batch, label_batch, bbox_batch, landmark_batch = read_multi_tfrecords(
            dataset_dirs, batch_sizes, net, no_landmarks)

    # Per-net input size and loss weights (ONet weights landmarks higher).
    if net == 'PNet':
        image_size = 12
        radio_cls_loss = 1.0
        radio_bbox_loss = 0.5
        radio_landmark_loss = 0.5
    elif net == 'RNet':
        image_size = 24
        radio_cls_loss = 1.0
        radio_bbox_loss = 0.5
        radio_landmark_loss = 0.5
    else:
        radio_cls_loss = 1.0
        radio_bbox_loss = 0.5
        radio_landmark_loss = 1
        image_size = 48

    # Define placeholders fed from the dequeued numpy batches each step.
    input_image = tf.placeholder(
        tf.float32,
        shape=[config.BATCH_SIZE, image_size, image_size, 3],
        name='input_image')
    label = tf.placeholder(tf.float32, shape=[config.BATCH_SIZE], name='label')
    bbox_target = tf.placeholder(tf.float32,
                                 shape=[config.BATCH_SIZE, 4],
                                 name='bbox_target')
    landmark_target = tf.placeholder(
        tf.float32,
        shape=[config.BATCH_SIZE, no_landmarks * 2],
        name='landmark_target')

    # Color augmentation is part of the graph; losses/accuracy come from the factory.
    input_image = image_color_distort(input_image)
    cls_loss_op, bbox_loss_op, landmark_loss_op, L2_loss_op, accuracy_op, landmark_pred = net_factory(
        input_image, label, bbox_target, landmark_target, training=True)
    # Weighted sum of the three task losses plus the L2 regularizer.
    # count_nan_op = count_nan(landmark_pred)
    total_loss_op = radio_cls_loss * cls_loss_op + radio_bbox_loss * bbox_loss_op + radio_landmark_loss * landmark_loss_op + L2_loss_op
    train_op, lr_op = train_model(base_lr, total_loss_op, num, quantize, optimizer)

    # init
    sess = tf.Session()
    saver = tf.train.Saver(max_to_keep=10)
    step = 0
    if ckpt is not None:
        saver.restore(sess, ckpt)
        # Resume the global step from the checkpoint filename suffix ('-N').
        step = int(os.path.basename(ckpt).split('-')[1])
        print('restored from last step = ', step)
    else:
        init = tf.global_variables_initializer()
        sess.run(init)

    # Visualize losses/accuracy in TensorBoard.
    tf.summary.scalar("cls_loss", cls_loss_op)  #cls_loss
    tf.summary.scalar("bbox_loss", bbox_loss_op)  #bbox_loss
    tf.summary.scalar("landmark_loss", landmark_loss_op)  #landmark_loss
    tf.summary.scalar("cls_accuracy", accuracy_op)  #cls_acc
    tf.summary.scalar(
        "total_loss", total_loss_op
    )  #cls_loss, bbox loss, landmark loss and L2 loss add together
    summary_op = tf.summary.merge_all()
    logs_dir = os.path.join(log_dir, net)
    # FIX: 'not os.path.exists' instead of '== False'; makedirs so missing
    # parent directories do not crash the run (mkdir only creates one level).
    if not os.path.exists(logs_dir):
        os.makedirs(logs_dir)
    writer = tf.summary.FileWriter(logs_dir, sess.graph)
    projector_config = projector.ProjectorConfig()
    projector.visualize_embeddings(writer, projector_config)

    # Start the tfrecord input pipeline threads.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    # i = 0
    # Total steps: ceil-ish steps per epoch times the number of epochs.
    step_per_epoch = int(num / config.BATCH_SIZE + 1)
    print('step_per_epoch = ', step_per_epoch)
    MAX_STEP = step_per_epoch * end_epoch
    epoch = 0
    sess.graph.finalize()
    current_total_loss = 100000  # best total loss seen so far (save trigger)
    try:
        for i in range(MAX_STEP):
            # i = i + 1
            # j = i
            step = step + 1
            if coord.should_stop():
                break
            # print ('train step = ', step, image_batch.shape, bbox_batch.shape, landmark_batch.shape)
            image_batch_array, label_batch_array, bbox_batch_array, landmark_batch_array = sess.run(
                [image_batch, label_batch, bbox_batch, landmark_batch])
            # Random horizontal flip augmentation (also mirrors the landmarks).
            # print('after batch array')
            image_batch_array, landmark_batch_array = random_flip_images(
                image_batch_array, label_batch_array, landmark_batch_array)
            # print('->>>>> 1')
            _, _, summary = sess.run(
                [train_op, lr_op, summary_op],
                feed_dict={
                    input_image: image_batch_array,
                    label: label_batch_array,
                    bbox_target: bbox_batch_array,
                    landmark_target: landmark_batch_array
                })
            if (step + 1) % display == 0:
                #acc = accuracy(cls_pred, labels_batch)
                # print('->>>>> 2')
                cls_loss, bbox_loss, landmark_loss, L2_loss, lr, acc, landmark_pred_val = sess.run(
                    [
                        cls_loss_op, bbox_loss_op, landmark_loss_op,
                        L2_loss_op, lr_op, accuracy_op, landmark_pred
                    ],
                    feed_dict={
                        input_image: image_batch_array,
                        label: label_batch_array,
                        bbox_target: bbox_batch_array,
                        landmark_target: landmark_batch_array
                    })
                # Abort training if the landmark loss diverged to NaN, and
                # report where the NaNs came from (predictions vs targets).
                if math.isnan(landmark_loss):
                    print('break, landmark loss is nan', landmark_loss)
                    print('landmark pred val ', landmark_pred_val)
                    # nan_count = sess.run([count_nan_op], feed_dict={landmark_pred: landmark_pred_val})
                    print('no of nan in landmark_pred_val', count_nan(landmark_pred_val))
                    print('no of nan in landmark_target', count_nan(landmark_batch_array))
                    # print('other metrics ', square_error_val, k_index_val, valid_inds_val)
                    break
                total_loss = radio_cls_loss * cls_loss + radio_bbox_loss * bbox_loss + radio_landmark_loss * landmark_loss + L2_loss
                # landmark loss: %4f,
                print(
                    "%s : Step: %d/%d, accuracy: %3f, cls loss: %4f, bbox loss: %4f,Landmark loss :%4f,L2 loss: %4f, Total Loss: %4f ,lr:%f "
                    % (datetime.now(), step + 1, MAX_STEP, acc, cls_loss,
                       bbox_loss, landmark_loss, L2_loss, total_loss, lr))
                # Save only when the displayed total loss improves.
                if total_loss < current_total_loss:
                    current_total_loss = total_loss
                    path_prefix = saver.save(sess, prefix, global_step=step)
                    print('Total loss improved, save model ', path_prefix)
            # save every end of epochs
            # if i > 0 and i % step_per_epoch == 0:
            #     path_prefix = saver.save(sess, prefix, global_step=step)
            #     print('Save end of epoch, path prefix is :', path_prefix)
            writer.add_summary(summary, global_step=step)
    except tf.errors.OutOfRangeError:
        print("完成!!!")
    finally:
        coord.request_stop()
        writer.close()
    coord.join(threads)
    sess.close()
def train(net_factory, model_save_path, max_epoch, tfrecord_path, display=200, base_lr=0.01):
    """Train PNet/RNet/ONet, saving a checkpoint every two epochs.

    :param net_factory: model-building callable for the corresponding net;
        returns the cls/bbox/landmark/L2 loss ops and the accuracy op
    :param model_save_path: checkpoint save path; its basename selects the
        net ('PNet' / 'RNet' / 'ONet')
    :param max_epoch: number of epochs to train for
    :param tfrecord_path: sequence of paths — [0] label txt file; for PNet
        [1] is the single tfrecord, otherwise [1..4] are the pos, part,
        neg, landmark tfrecords
    :param display: evaluate/log every `display` steps
    :param base_lr: initial learning rate passed to train_model
    :return: None
    """
    net = model_save_path.split('/')[-1]
    # Label file: one line per training example; used only to count the dataset.
    label_file = tfrecord_path[0]
    # label_file = os.path.join(base_dir,'landmark_12_few.txt')
    print(label_file)
    # FIX: context manager so the label file is always closed (was leaked).
    with open(label_file, 'r') as f:
        # get number of training examples
        num = len(f.readlines())
    print("Total size of the dataset is: ", num)
    print(model_save_path)

    # PNet reads one shuffled tfrecord; RNet/ONet mix four record files.
    if net == 'PNet':
        print('dataset dir is:', tfrecord_path[1])
        image_batch, label_batch, bbox_batch, landmark_batch = \
            read_single_tfrecord(tfrecord_path[1], config.BATCH_SIZE, net)
    else:
        pos_dir = tfrecord_path[1]
        part_dir = tfrecord_path[2]
        neg_dir = tfrecord_path[3]
        landmark_dir = tfrecord_path[4]
        dataset_dirs = [pos_dir, part_dir, neg_dir, landmark_dir]
        # Per-batch sampling ratios: 1/6 pos, 1/6 part, 1/6 landmark, 3/6 neg.
        pos_radio = 1.0 / 6
        part_radio = 1.0 / 6
        landmark_radio = 1.0 / 6
        neg_radio = 3.0 / 6
        pos_batch_size = int(np.ceil(config.BATCH_SIZE * pos_radio))
        assert pos_batch_size != 0, "Batch Size Error "
        part_batch_size = int(np.ceil(config.BATCH_SIZE * part_radio))
        assert part_batch_size != 0, "Batch Size Error "
        neg_batch_size = int(np.ceil(config.BATCH_SIZE * neg_radio))
        assert neg_batch_size != 0, "Batch Size Error "
        landmark_batch_size = int(np.ceil(config.BATCH_SIZE * landmark_radio))
        assert landmark_batch_size != 0, "Batch Size Error "
        batch_sizes = [
            pos_batch_size, part_batch_size, neg_batch_size,
            landmark_batch_size
        ]
        # print('batch_size is:', batch_sizes)
        image_batch, label_batch, bbox_batch, landmark_batch = read_multi_tfrecords(
            dataset_dirs, batch_sizes, net)

    path_config = PathConfiguration().config
    # Per-net input size, loss weights and TensorBoard log directory
    # (ONet weights landmarks higher).
    if net == 'PNet':
        image_size = 12
        radio_cls_loss = 1.0
        radio_bbox_loss = 1.0
        radio_landmark_loss = 0.5
        logs_dir = path_config.pnet_log_path
    elif net == 'RNet':
        image_size = 24
        radio_cls_loss = 1.0
        radio_bbox_loss = 1.0
        radio_landmark_loss = 0.5
        logs_dir = path_config.rnet_log_path
    else:
        radio_cls_loss = 1.0
        radio_bbox_loss = 1.0
        radio_landmark_loss = 1.0
        image_size = 48
        logs_dir = path_config.onet_log_path

    # Define placeholders fed from the dequeued numpy batches each step.
    input_image = tf.placeholder(
        tf.float32,
        shape=[config.BATCH_SIZE, image_size, image_size, 3],
        name='input_image')
    label = tf.placeholder(tf.float32, shape=[config.BATCH_SIZE], name='label')
    bbox_target = tf.placeholder(tf.float32,
                                 shape=[config.BATCH_SIZE, 4],
                                 name='bbox_target')
    # 5 landmarks -> 10 flattened (x, y) coordinates.
    landmark_target = tf.placeholder(tf.float32,
                                     shape=[config.BATCH_SIZE, 10],
                                     name='landmark_target')

    # Color augmentation is part of the graph; losses/accuracy come from the factory.
    input_image = image_color_distort(input_image)
    cls_loss_op, bbox_loss_op, landmark_loss_op, L2_loss_op, accuracy_op = net_factory(
        input_image, label, bbox_target, landmark_target, training=True)
    # Weighted sum of the three task losses plus the L2 regularizer.
    total_loss_op = radio_cls_loss * cls_loss_op + radio_bbox_loss * bbox_loss_op
    total_loss_op = total_loss_op + radio_landmark_loss * landmark_loss_op + L2_loss_op
    train_op, lr_op = train_model(base_lr, total_loss_op, num)

    # init
    os.environ["CUDA_VISIBLE_DEVICES"] = config.VISIBLE_GPU  # '0,1,2,3'
    # NOTE(review): tf.device() is not used as a context manager here, so it
    # has no effect on op placement — presumably intended as `with tf.device(...):`
    # around the graph construction; confirm before relying on GPU pinning.
    tf.device('/gpu:{}'.format(config.GPU))
    init = tf.global_variables_initializer()
    sess = tf.Session()
    # max_to_keep=0 keeps every checkpoint ever saved
    saver = tf.train.Saver(max_to_keep=0)
    sess.run(init)

    # Visualize losses/accuracy in TensorBoard.
    tf.summary.scalar("cls_loss", cls_loss_op)  # cls_loss
    tf.summary.scalar("bbox_loss", bbox_loss_op)  # bbox_loss
    tf.summary.scalar("landmark_loss", landmark_loss_op)  # landmark_loss
    tf.summary.scalar("cls_accuracy", accuracy_op)  # cls_acc
    tf.summary.scalar(
        "total_loss", total_loss_op
    )  # cls_loss, bbox loss, landmark loss and L2 loss add together
    summary_op = tf.summary.merge_all()
    if not os.path.exists(logs_dir):
        os.makedirs(logs_dir)
    writer = tf.summary.FileWriter(logs_dir, sess.graph)
    projector_config = projector.ProjectorConfig()
    projector.visualize_embeddings(writer, projector_config)

    # Start the tfrecord input pipeline threads.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    i = 0  # steps since the last checkpoint (reset every two epochs)
    # Total steps: steps per epoch times the number of epochs.
    MAX_STEP = int(num / config.BATCH_SIZE + 1) * max_epoch
    epoch = 0
    sess.graph.finalize()
    try:
        for step in range(MAX_STEP):
            i = i + 1
            if coord.should_stop():
                break
            image_batch_array, label_batch_array, bbox_batch_array, landmark_batch_array = sess.run(
                [image_batch, label_batch, bbox_batch, landmark_batch])
            # Random horizontal flip augmentation (also mirrors the landmarks).
            image_batch_array, landmark_batch_array = random_flip_images(
                image_batch_array, label_batch_array, landmark_batch_array)
            '''
            print('im here')
            print(image_batch_array.shape)
            print(label_batch_array.shape)
            print(bbox_batch_array.shape)
            print(landmark_batch_array.shape)
            print(label_batch_array[0])
            print(bbox_batch_array[0])
            print(landmark_batch_array[0])
            '''
            _, _, summary = sess.run(
                [train_op, lr_op, summary_op],
                feed_dict={
                    input_image: image_batch_array,
                    label: label_batch_array,
                    bbox_target: bbox_batch_array,
                    landmark_target: landmark_batch_array
                })
            if (step + 1) % display == 0:
                cls_loss, bbox_loss, landmark_loss, L2_loss, lr, acc = sess.run(
                    [
                        cls_loss_op, bbox_loss_op, landmark_loss_op,
                        L2_loss_op, lr_op, accuracy_op
                    ],
                    feed_dict={
                        input_image: image_batch_array,
                        label: label_batch_array,
                        bbox_target: bbox_batch_array,
                        landmark_target: landmark_batch_array
                    })
                total_loss = radio_cls_loss * cls_loss + radio_bbox_loss * bbox_loss
                total_loss = total_loss + radio_landmark_loss * landmark_loss + L2_loss
                # landmark loss: %4f,
                print((
                    "%s - Step: %d/%d, accuracy: %3f, cls loss: %4f, bbox loss: %4f, "
                    + "Landmark loss :%4f,L2 loss: %4f, Total Loss: %4f ,lr:%f ")
                    % (datetime.now(), step + 1, MAX_STEP, acc, cls_loss,
                       bbox_loss, landmark_loss, L2_loss, total_loss, lr))
            # Save every two epochs (i counts steps; num * 2 examples seen).
            if i * config.BATCH_SIZE > num * 2:
                epoch = epoch + 1
                i = 0
                path_prefix = saver.save(sess,
                                         model_save_path,
                                         global_step=epoch * 2)
                print('path prefix is :', path_prefix)
            writer.add_summary(summary, global_step=step)
    except tf.errors.OutOfRangeError:
        print("异常结束!!!")
    finally:
        print("完成!!!")
        coord.request_stop()
        writer.close()
    coord.join(threads)
    sess.close()
def train(net_factory, prefix, end_epoch, base_dir, display=200, base_lr=0.01):
    """Train PNet/RNet/ONet, saving a checkpoint every two epochs.

    :param net_factory: model-building callable for the corresponding net;
        returns the cls/bbox/landmark/L2 loss ops and the accuracy op
    :param prefix: model save path; its basename selects the net
        ('PNet' / 'RNet' / 'ONet')
    :param end_epoch: number of epochs to train for
    :param base_dir: directory holding the label txt and tfrecord files
    :param display: evaluate/log every `display` steps
    :param base_lr: initial learning rate passed to train_model
    :return: None
    """
    net = prefix.split('/')[-1]
    # Label file: one line per training example; used only to count the dataset.
    label_file = os.path.join(base_dir, 'train_%s_landmark.txt' % net)
    #label_file = os.path.join(base_dir,'landmark_12_few.txt')
    print(label_file)
    # FIX: context manager so the label file is always closed (was leaked).
    with open(label_file, 'r') as f:
        # get number of training examples
        num = len(f.readlines())
    print("Total size of the dataset is: ", num)
    print(prefix)

    # PNet reads one shuffled tfrecord; RNet/ONet mix four record files.
    if net == 'PNet':
        #dataset_dir = os.path.join(base_dir,'train_%s_ALL.tfrecord_shuffle' % net)
        dataset_dir = os.path.join(base_dir, 'train_%s_landmark.tfrecord_shuffle' % net)
        print('dataset dir is:', dataset_dir)
        image_batch, label_batch, bbox_batch, landmark_batch = read_single_tfrecord(
            dataset_dir, config.BATCH_SIZE, net)
    else:
        pos_dir = os.path.join(base_dir, 'pos_landmark.tfrecord_shuffle')
        part_dir = os.path.join(base_dir, 'part_landmark.tfrecord_shuffle')
        neg_dir = os.path.join(base_dir, 'neg_landmark.tfrecord_shuffle')
        #landmark_dir = os.path.join(base_dir,'landmark_landmark.tfrecord_shuffle')
        # NOTE(review): hard-coded relative path ignores base_dir — presumably
        # a leftover from a specific checkout layout; confirm before reuse.
        landmark_dir = os.path.join('../../DATA/imglists/RNet',
                                    'landmark_landmark.tfrecord_shuffle')
        dataset_dirs = [pos_dir, part_dir, neg_dir, landmark_dir]
        # Per-batch sampling ratios: 1/6 pos, 1/6 part, 1/6 landmark, 3/6 neg.
        pos_radio = 1.0 / 6
        part_radio = 1.0 / 6
        landmark_radio = 1.0 / 6
        neg_radio = 3.0 / 6
        pos_batch_size = int(np.ceil(config.BATCH_SIZE * pos_radio))
        assert pos_batch_size != 0, "Batch Size Error "
        part_batch_size = int(np.ceil(config.BATCH_SIZE * part_radio))
        assert part_batch_size != 0, "Batch Size Error "
        neg_batch_size = int(np.ceil(config.BATCH_SIZE * neg_radio))
        assert neg_batch_size != 0, "Batch Size Error "
        landmark_batch_size = int(np.ceil(config.BATCH_SIZE * landmark_radio))
        assert landmark_batch_size != 0, "Batch Size Error "
        batch_sizes = [
            pos_batch_size, part_batch_size, neg_batch_size,
            landmark_batch_size
        ]
        #print('batch_size is:', batch_sizes)
        image_batch, label_batch, bbox_batch, landmark_batch = read_multi_tfrecords(
            dataset_dirs, batch_sizes, net)

    #landmark_dir
    # Per-net input size and loss weights (ONet weights landmarks higher).
    if net == 'PNet':
        image_size = 12
        radio_cls_loss = 1.0
        radio_bbox_loss = 0.5
        radio_landmark_loss = 0.5
    elif net == 'RNet':
        image_size = 24
        radio_cls_loss = 1.0
        radio_bbox_loss = 0.5
        radio_landmark_loss = 0.5
    else:
        radio_cls_loss = 1.0
        radio_bbox_loss = 0.5
        radio_landmark_loss = 1
        image_size = 48

    # Define placeholders fed from the dequeued numpy batches each step.
    input_image = tf.placeholder(
        tf.float32,
        shape=[config.BATCH_SIZE, image_size, image_size, 3],
        name='input_image')
    label = tf.placeholder(tf.float32, shape=[config.BATCH_SIZE], name='label')
    bbox_target = tf.placeholder(tf.float32,
                                 shape=[config.BATCH_SIZE, 4],
                                 name='bbox_target')
    # 5 landmarks -> 10 flattened (x, y) coordinates.
    landmark_target = tf.placeholder(tf.float32,
                                     shape=[config.BATCH_SIZE, 10],
                                     name='landmark_target')

    # Color augmentation is part of the graph; losses/accuracy come from the factory.
    input_image = image_color_distort(input_image)
    cls_loss_op, bbox_loss_op, landmark_loss_op, L2_loss_op, accuracy_op = net_factory(
        input_image, label, bbox_target, landmark_target, training=True)
    # Weighted sum of the three task losses plus the L2 regularizer.
    total_loss_op = radio_cls_loss * cls_loss_op + radio_bbox_loss * bbox_loss_op + radio_landmark_loss * landmark_loss_op + L2_loss_op
    train_op, lr_op = train_model(base_lr, total_loss_op, num)

    # init
    init = tf.global_variables_initializer()
    sess = tf.Session()
    # max_to_keep=0 keeps every checkpoint ever saved
    saver = tf.train.Saver(max_to_keep=0)
    sess.run(init)

    # Visualize losses/accuracy in TensorBoard.
    tf.summary.scalar("cls_loss", cls_loss_op)  #cls_loss
    tf.summary.scalar("bbox_loss", bbox_loss_op)  #bbox_loss
    tf.summary.scalar("landmark_loss", landmark_loss_op)  #landmark_loss
    tf.summary.scalar("cls_accuracy", accuracy_op)  #cls_acc
    tf.summary.scalar(
        "total_loss", total_loss_op
    )  #cls_loss, bbox loss, landmark loss and L2 loss add together
    summary_op = tf.summary.merge_all()
    logs_dir = "./logs/%s" % (net)
    # FIX: 'not os.path.exists' instead of '== False'; makedirs so the missing
    # './logs' parent does not crash the run (mkdir only creates one level).
    if not os.path.exists(logs_dir):
        os.makedirs(logs_dir)
    writer = tf.summary.FileWriter(logs_dir, sess.graph)
    projector_config = projector.ProjectorConfig()
    projector.visualize_embeddings(writer, projector_config)

    # Start the tfrecord input pipeline threads.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    i = 0  # steps since the last checkpoint (reset every two epochs)
    # Total steps: steps per epoch times the number of epochs.
    MAX_STEP = int(num / config.BATCH_SIZE + 1) * end_epoch
    epoch = 0
    sess.graph.finalize()
    try:
        for step in range(MAX_STEP):
            i = i + 1
            if coord.should_stop():
                break
            image_batch_array, label_batch_array, bbox_batch_array, landmark_batch_array = sess.run(
                [image_batch, label_batch, bbox_batch, landmark_batch])
            # Random horizontal flip augmentation (also mirrors the landmarks).
            image_batch_array, landmark_batch_array = random_flip_images(
                image_batch_array, label_batch_array, landmark_batch_array)
            '''
            print('im here')
            print(image_batch_array.shape)
            print(label_batch_array.shape)
            print(bbox_batch_array.shape)
            print(landmark_batch_array.shape)
            print(label_batch_array[0])
            print(bbox_batch_array[0])
            print(landmark_batch_array[0])
            '''
            _, _, summary = sess.run(
                [train_op, lr_op, summary_op],
                feed_dict={
                    input_image: image_batch_array,
                    label: label_batch_array,
                    bbox_target: bbox_batch_array,
                    landmark_target: landmark_batch_array
                })
            if (step + 1) % display == 0:
                #acc = accuracy(cls_pred, labels_batch)
                cls_loss, bbox_loss, landmark_loss, L2_loss, lr, acc = sess.run(
                    [
                        cls_loss_op, bbox_loss_op, landmark_loss_op,
                        L2_loss_op, lr_op, accuracy_op
                    ],
                    feed_dict={
                        input_image: image_batch_array,
                        label: label_batch_array,
                        bbox_target: bbox_batch_array,
                        landmark_target: landmark_batch_array
                    })
                total_loss = radio_cls_loss * cls_loss + radio_bbox_loss * bbox_loss + radio_landmark_loss * landmark_loss + L2_loss
                # landmark loss: %4f,
                print(
                    "%s : Step: %d/%d, accuracy: %3f, cls loss: %4f, bbox loss: %4f,Landmark loss :%4f,L2 loss: %4f, Total Loss: %4f ,lr:%f "
                    % (datetime.now(), step + 1, MAX_STEP, acc, cls_loss,
                       bbox_loss, landmark_loss, L2_loss, total_loss, lr))
            # Save every two epochs (i counts steps; num * 2 examples seen).
            if i * config.BATCH_SIZE > num * 2:
                epoch = epoch + 1
                i = 0
                path_prefix = saver.save(sess, prefix, global_step=epoch * 2)
                print('path prefix is :', path_prefix)
            writer.add_summary(summary, global_step=step)
    except tf.errors.OutOfRangeError:
        print("完成!!!")
    finally:
        coord.request_stop()
        writer.close()
    coord.join(threads)
    sess.close()