def eval_net(net_factory, val_tfrecords, net_size):
    # Get val data from tfrecords.
    val_cls_data_label = stat_tfrecords.ReadTFRecord(val_tfrecords, net_size, 3)
    val_image_batch, val_cls_label_batch, val_reg_label_batch = \
        tf.train.batch([val_cls_data_label['image'],
                        val_cls_data_label['cls_label'],
                        val_cls_data_label['reg_label']],
                       batch_size=FLAGS.batch_size,
                       num_threads=16,
                       allow_smaller_final_batch=True)
    val_cls_prob, val_bbox_pred, _ = net_factory(val_image_batch,
                                                 is_training=False,
                                                 mode='VERIFY')
    # Return eval ops.
    eval_cls_op, eval_bbox_pred_op = train_core.compute_accuracy(
        val_cls_prob, val_bbox_pred, val_cls_label_batch, val_reg_label_batch)
    return eval_cls_op, eval_bbox_pred_op
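
# --- Illustrative sketch (not part of the original pipeline) ----------------
# train_core.compute_accuracy is called above but defined elsewhere. The
# sketch below shows one plausible contract for it, assuming MTCNN-style
# labels where cls_label is 1 for positives, 0 for negatives, and -1 for part
# samples excluded from classification accuracy. The name
# `_compute_accuracy_sketch` is hypothetical.
def _compute_accuracy_sketch(cls_prob, bbox_pred, cls_label, reg_label):
    # Keep only samples with a valid classification label (>= 0).
    valid_mask = tf.greater_equal(tf.cast(cls_label, tf.int32), 0)
    valid_prob = tf.boolean_mask(cls_prob, valid_mask)
    valid_label = tf.boolean_mask(tf.cast(cls_label, tf.int32), valid_mask)
    # Classification accuracy: argmax over the 2-way softmax vs. the label.
    pred_class = tf.argmax(valid_prob, axis=1, output_type=tf.int32)
    accuracy_op = tf.reduce_mean(
        tf.cast(tf.equal(pred_class, valid_label), tf.float32))
    # Per-coordinate squared deviation of the bbox regression output, matching
    # the "Bbox_reg_square_mean_deviation" printout in train_net below.
    bbox_error_op = tf.reduce_mean(tf.square(bbox_pred - reg_label), axis=0)
    return accuracy_op, bbox_error_op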
def train_net(net_factory, model_prefix, logdir, end_epoch, net_size,
              tfrecords, val_tfrecords=(), frequent=500):
    # Set logging verbosity.
    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        #########################################
        # Get Detect train data from tfrecords. #
        #########################################
        cls_data_label = stat_tfrecords.ReadTFRecord(tfrecords, net_size, 3)
        # See:
        # https://stackoverflow.com/questions/43028683/whats-going-on-in-tf-train-shuffle-batch-and-tf-train-batch?answertab=votes#tab-top
        # https://stackoverflow.com/questions/34258043/getting-good-mixing-with-many-input-datafiles-in-tensorflow/34258214#34258214
        # batch_size, capacity, and min_after_dequeue together determine how
        # thoroughly the examples are shuffled before batching.
        image_batch, cls_label_batch, reg_label_batch = \
            tf.train.shuffle_batch([cls_data_label['image'],
                                    cls_data_label['cls_label'],
                                    cls_data_label['reg_label']],
                                   batch_size=FLAGS.batch_size,
                                   capacity=20000,
                                   min_after_dequeue=10000,
                                   num_threads=16,
                                   allow_smaller_final_batch=True)
        if len(val_tfrecords):
            eval_cls_op, eval_bbox_pred_op = eval_net(net_factory,
                                                      val_tfrecords, net_size)

        # Network forward.
        if FLAGS.is_ERC:
            cls_prob_op, bbox_pred_op, ERC1_loss_op, ERC2_loss_op, \
                cls_loss_op, bbox_loss_op, end_points = \
                net_factory(image_batch, cls_label_batch, reg_label_batch)
        else:
            cls_prob_op, bbox_pred_op, cls_loss_op, bbox_loss_op, end_points = \
                net_factory(image_batch, cls_label_batch, reg_label_batch,
                            mode='TRAIN')

        #########################################
        # Configure the optimization procedure. #
        #########################################
        global_step = tf.Variable(0, trainable=False)
        # Decay the learning rate at the epoch boundaries listed in LR_EPOCH.
        boundaries = [int(epoch * FLAGS.image_sum / FLAGS.batch_size)
                      for epoch in LR_EPOCH]
        lr_values = [FLAGS.lr * (FLAGS.lr_decay_factor ** x)
                     for x in range(0, len(LR_EPOCH) + 1)]
        lr_op = tf.train.piecewise_constant(global_step, boundaries, lr_values)
        optimizer = train_core.configure_optimizer(lr_op)
        if FLAGS.is_ERC:
            total_loss = ERC1_loss_op + ERC2_loss_op + cls_loss_op + bbox_loss_op
        else:
            total_loss = train_core.task_add_weight(cls_loss_op, bbox_loss_op)
        train_op = optimizer.minimize(total_loss, global_step)

        #########################################
        # Save train/verify summary.            #
        #########################################
        tf.summary.scalar('learning_rate', lr_op)
        tf.summary.scalar('cls_loss', cls_loss_op)
        tf.summary.scalar('bbox_reg_loss', bbox_loss_op)
        if FLAGS.is_ERC:
            tf.summary.scalar('loss_ERC1', ERC1_loss_op)
            tf.summary.scalar('loss_ERC2', ERC2_loss_op)
            tf.summary.scalar('loss_last', cls_loss_op)
            tf.summary.scalar('loss_sum',
                              ERC1_loss_op + ERC2_loss_op + cls_loss_op + bbox_loss_op)
        else:
            tf.summary.scalar('loss_sum', cls_loss_op + bbox_loss_op)
        # Save feature map histograms.
        if FLAGS.is_feature_visual:
            for feature_name, feature_val in end_points.items():
                print(feature_name)
                tf.summary.histogram(feature_name, feature_val)

        #########################################
        # Checkpoint or restore model           #
        #########################################
        model_dir = model_prefix.rsplit('/', 1)[0]
        if not os.path.exists(model_dir):
            os.makedirs(model_dir)
        latest_ckpt = tf.train.latest_checkpoint(model_dir)
        # Let the GPU memory grow on demand instead of grabbing it all.
        tf_config = tf.ConfigProto()
        # tf_config.gpu_options.per_process_gpu_memory_fraction = 0.4
        tf_config.gpu_options.allow_growth = True
        sess = tf.Session(config=tf_config)
        saver = tf.train.Saver()
        coord = tf.train.Coordinator()
        if latest_ckpt is not None:
            saver.restore(sess, latest_ckpt)
            # Resume from the epoch number encoded at the end of the
            # checkpoint name.
            start_epoch = int(
                next(re.finditer(r"(\d+)(?!.*\d)", latest_ckpt)).group(0))
        else:
            sess.run(tf.global_variables_initializer())
            start_epoch = 1
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        summary_writer = tf.summary.FileWriter(logdir, sess.graph)
        summary_op = tf.summary.merge_all()

        #########################################
        # Main Training/Verify Loop             #
        #########################################
        # allow_smaller_final_batch avoids discarding the tail of the data.
        n_step_epoch = int(np.ceil(FLAGS.image_sum / FLAGS.batch_size))
        for cur_epoch in range(start_epoch, end_epoch + 1):
            cls_loss_list = []
            bbox_loss_list = []
            for batch_idx in range(n_step_epoch):
                _, cls_pred, bbox_pred, cls_loss, bbox_loss, lr, summary_str, gb_step = \
                    sess.run([train_op, cls_prob_op, bbox_pred_op,
                              cls_loss_op, bbox_loss_op, lr_op,
                              summary_op, global_step])
                cls_loss_list.append(cls_loss)
                bbox_loss_list.append(bbox_loss)
                if not batch_idx % frequent:
                    summary_writer.add_summary(summary_str, gb_step)
                    print("%s: Epoch: %d, cls loss: %.4f, bbox loss: %.4f, learning_rate: %.4f"
                          % (datetime.now(), cur_epoch, np.mean(cls_loss_list),
                             np.mean(bbox_loss_list), lr))
                    sys.stdout.flush()
            print("%s: Epoch: %d, cls loss: %.4f, bbox loss: %.4f"
                  % (datetime.now(), cur_epoch, np.mean(cls_loss_list),
                     np.mean(bbox_loss_list)))
            saver.save(sess, model_prefix, cur_epoch)
            # Compute val-set accuracy, using the same batch_size as training.
            if len(val_tfrecords):
                n_step_val = int(np.ceil(FLAGS.val_image_sum / FLAGS.batch_size))
                val_cls_prob = []
                val_bbox_error = []
                for step in range(n_step_val):
                    val_prob, val_bbox = sess.run([eval_cls_op, eval_bbox_pred_op])
                    val_cls_prob.append(val_prob)
                    val_bbox_error.append(val_bbox)
                print("Epoch: %d, cls accuracy: %.4f"
                      % (cur_epoch, np.mean(val_cls_prob)))
                print("Bbox_reg_square_mean_deviation: "
                      + str(np.mean(val_bbox_error, axis=0)))
        coord.request_stop()
        coord.join(threads)
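
# --- Illustrative sketch of the LR schedule used above ----------------------
# The piecewise-constant schedule decays FLAGS.lr by FLAGS.lr_decay_factor at
# each epoch listed in LR_EPOCH. The sketch below replays the same arithmetic
# in plain Python; the concrete numbers (lr=0.01, decay=0.1, lr_epoch=(7, 15),
# image_sum=100000, batch_size=384) are made-up values for illustration only.
def _lr_schedule_sketch(step, lr=0.01, decay=0.1,
                        lr_epoch=(7, 15), image_sum=100000, batch_size=384):
    # Same computation as boundaries/lr_values in train_net.
    boundaries = [int(e * image_sum / batch_size) for e in lr_epoch]
    values = [lr * (decay ** i) for i in range(len(lr_epoch) + 1)]
    for boundary, value in zip(boundaries, values):
        if step < boundary:
            return value
    return values[-1]

# e.g. _lr_schedule_sketch(0) -> 0.01 and _lr_schedule_sketch(2000) -> 0.001,
# since the first boundary falls at int(7 * 100000 / 384) = 1822 steps.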
def train_attribute_net(net_factory, model_prefix, logdir, end_epoch,
                        net_size, tfrecords, attrib_tfrecords, frequent=500):
    with tf.Graph().as_default():
        #######################################################
        # Get Attribute and Detect train data from tfrecords. #
        #######################################################
        cls_data_label_dict = stat_tfrecords.ReadTFRecord(tfrecords, net_size, 3)
        cls_image_batch, cls_label_batch, reg_label_batch = \
            tf.train.shuffle_batch([cls_data_label_dict['image'],
                                    cls_data_label_dict['cls_label'],
                                    cls_data_label_dict['reg_label']],
                                   batch_size=FLAGS.batch_size,
                                   capacity=20000,
                                   min_after_dequeue=10000,
                                   num_threads=16,
                                   allow_smaller_final_batch=True)
        attrib_label_dict = stat_tfrecords.ReadTFRecord(
            attrib_tfrecords, net_size, 3,
            vector_name=['head_pose', 'landmark_label'],
            vector_dim=[3, 68 * 2])
        attrib_image_batch, attrib_cls_label_batch, attrib_bbox_reg_label_batch, \
            attrib_head_pose_batch, attrib_land_reg_batch = tf.train.shuffle_batch(
                [attrib_label_dict['image'],
                 attrib_label_dict['cls_label'],
                 attrib_label_dict['reg_label'],
                 attrib_label_dict['head_pose'],
                 attrib_label_dict['landmark_label']],
                batch_size=FLAGS.batch_size,
                capacity=20000,
                min_after_dequeue=10000,
                num_threads=16,
                allow_smaller_final_batch=True)
        # Concatenate the detection and attribute batches so both streams
        # share a single forward pass.
        images = tf.concat([cls_image_batch, attrib_image_batch], axis=0)
        cls_labels = tf.concat([cls_label_batch, attrib_cls_label_batch], axis=0)
        bbox_labels = tf.concat([reg_label_batch, attrib_bbox_reg_label_batch],
                                axis=0)

        # Network forward.
        cls_prob_op, bbox_pred_op, head_pose_pred_op, landmark_pred_op, \
            cls_loss_op, bbox_loss_op, head_pose_loss_op, landmark_loss_op, \
            end_points = net_factory(inputs=images,
                                     label=cls_labels,
                                     bbox_target=bbox_labels,
                                     pose_reg_target=attrib_head_pose_batch,
                                     landmark_target=attrib_land_reg_batch)

        #########################################
        # Configure the optimization procedure. #
        #########################################
        global_step = tf.Variable(0, trainable=False)
        boundaries = [int(epoch * FLAGS.image_sum / FLAGS.batch_size)
                      for epoch in LR_EPOCH]
        lr_values = [FLAGS.lr * (FLAGS.lr_decay_factor ** x)
                     for x in range(0, len(LR_EPOCH) + 1)]
        lr_op = tf.train.piecewise_constant(global_step, boundaries, lr_values)
        optimizer = train_core.configure_optimizer(lr_op)
        train_op = optimizer.minimize(
            train_core.task_add_weight(cls_loss_op, bbox_loss_op,
                                       landmark_loss_op=landmark_loss_op,
                                       pose_loss_op=head_pose_loss_op),
            global_step)

        #########################################
        # Save train/verify summary.            #
        #########################################
        # Save feature map histograms.
        if FLAGS.is_feature_visual:
            for feature_name, feature_val in end_points.items():
                tf.summary.histogram(feature_name, feature_val)
        tf.summary.scalar('learning_rate', lr_op)
        tf.summary.scalar('cls_loss', cls_loss_op)
        tf.summary.scalar('bbox_reg_loss', bbox_loss_op)
        tf.summary.scalar('head_pose_loss', head_pose_loss_op)
        tf.summary.scalar('landmark_reg_loss', landmark_loss_op)
        tf.summary.scalar('cls_bbox_reg_loss_sum', cls_loss_op + bbox_loss_op)

        #########################################
        # Checkpoint or restore model           #
        #########################################
        model_dir = model_prefix.rsplit('/', 1)[0]
        if not os.path.exists(model_dir):
            os.makedirs(model_dir)
        latest_ckpt = tf.train.latest_checkpoint(model_dir)
        # Let the GPU memory grow on demand instead of grabbing it all.
        tf_config = tf.ConfigProto()
        # tf_config.gpu_options.per_process_gpu_memory_fraction = 0.4
        tf_config.gpu_options.allow_growth = True
        sess = tf.Session(config=tf_config)
        coord = tf.train.Coordinator()
        saver = tf.train.Saver()
        if latest_ckpt is not None:
            saver.restore(sess, latest_ckpt)
            # Resume from the epoch number encoded at the end of the
            # checkpoint name.
            start_epoch = int(
                next(re.finditer(r"(\d+)(?!.*\d)", latest_ckpt)).group(0))
        else:
            sess.run(tf.global_variables_initializer())
            start_epoch = 1
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        summary_writer = tf.summary.FileWriter(logdir, sess.graph)
        summary_op = tf.summary.merge_all()

        #########################################
        # Main Training/Verify Loop             #
        #########################################
        # allow_smaller_final_batch avoids discarding the tail of the data.
        n_step_epoch = int(np.ceil(FLAGS.image_sum / FLAGS.batch_size))
        for cur_epoch in range(start_epoch, end_epoch + 1):
            cls_loss_list = []
            bbox_loss_list = []
            landmark_loss_list = []
            pose_reg_loss_list = []
            for batch_idx in range(n_step_epoch):
                _, cls_pred, bbox_pred, head_pose_pred, landmark_pred, \
                    cls_loss, bbox_loss, pose_reg_loss, landmark_loss, \
                    lr, summary_str, gb_step = sess.run(
                        [train_op, cls_prob_op, bbox_pred_op,
                         head_pose_pred_op, landmark_pred_op,
                         cls_loss_op, bbox_loss_op,
                         head_pose_loss_op, landmark_loss_op,
                         lr_op, summary_op, global_step])
                cls_loss_list.append(cls_loss)
                bbox_loss_list.append(bbox_loss)
                landmark_loss_list.append(landmark_loss)
                pose_reg_loss_list.append(pose_reg_loss)
                if not batch_idx % frequent:
                    summary_writer.add_summary(summary_str, gb_step)
                    print("%s: Epoch: %d, cls loss: %.4f, bbox loss: %.4f, "
                          "head_pose_loss: %.4f, landmark_loss: %.4f, learning_rate: %.4f"
                          % (datetime.now(), cur_epoch,
                             np.mean(cls_loss_list), np.mean(bbox_loss_list),
                             np.mean(pose_reg_loss_list),
                             np.mean(landmark_loss_list), lr))
                    sys.stdout.flush()
            print("%s: Epoch: %d, cls loss: %.4f, bbox loss: %.4f, "
                  "head_pose_loss: %.4f, landmark_loss: %.4f"
                  % (datetime.now(), cur_epoch,
                     np.mean(cls_loss_list), np.mean(bbox_loss_list),
                     np.mean(pose_reg_loss_list), np.mean(landmark_loss_list)))
            saver.save(sess, model_prefix, cur_epoch)
        coord.request_stop()
        coord.join(threads)
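
# --- Illustrative sketch (not part of the original pipeline) ----------------
# train_core.task_add_weight combines the per-task losses above but is defined
# in another module. One plausible implementation is a fixed weighted sum, as
# sketched below; the name `_task_add_weight_sketch` and the default weights
# are assumptions, not values taken from this project.
def _task_add_weight_sketch(cls_loss_op, bbox_loss_op,
                            landmark_loss_op=None, pose_loss_op=None,
                            cls_w=1.0, bbox_w=0.5, landmark_w=0.5, pose_w=0.5):
    # Weighted multi-task sum; optional tasks contribute only when provided,
    # matching the two call sites (detection-only vs. attribute training).
    total = cls_w * cls_loss_op + bbox_w * bbox_loss_op
    if landmark_loss_op is not None:
        total += landmark_w * landmark_loss_op
    if pose_loss_op is not None:
        total += pose_w * pose_loss_op
    return total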