def train_lanenet(dataset_dir, weights_path=None, net_flag='vgg'):
    """
    :param dataset_dir: directory holding the training/validation data
    :param net_flag: choose which base network to use
    :param weights_path: optional checkpoint to restore before training
    :return:
    """
    train_dataset = lanenet_data_feed_pipline.LaneNetDataFeeder(
        dataset_dir=dataset_dir, flags='train')
    val_dataset = lanenet_data_feed_pipline.LaneNetDataFeeder(
        dataset_dir=dataset_dir, flags='val')

    with tf.device('/gpu:1'):
        # set lanenet
        train_net = lanenet.LaneNet(net_flag=net_flag, phase='train', reuse=False)
        val_net = lanenet.LaneNet(net_flag=net_flag, phase='val', reuse=True)

        # set compute graph node for training
        train_images, train_binary_labels, train_instance_labels = train_dataset.inputs(
            CFG.TRAIN.BATCH_SIZE, 1)

        train_compute_ret = train_net.compute_loss(
            input_tensor=train_images, binary_label=train_binary_labels,
            instance_label=train_instance_labels, name='lanenet_model')

        train_total_loss = train_compute_ret['total_loss']
        train_binary_seg_loss = train_compute_ret['binary_seg_loss']
        train_disc_loss = train_compute_ret['discriminative_loss']
        train_pix_embedding = train_compute_ret['instance_seg_logits']

        train_prediction_logits = train_compute_ret['binary_seg_logits']
        train_prediction_score = tf.nn.softmax(logits=train_prediction_logits)
        train_prediction = tf.argmax(train_prediction_score, axis=-1)

        train_accuracy = evaluate_model_utils.calculate_model_precision(
            train_compute_ret['binary_seg_logits'], train_binary_labels)
        train_fp = evaluate_model_utils.calculate_model_fp(
            train_compute_ret['binary_seg_logits'], train_binary_labels)
        train_fn = evaluate_model_utils.calculate_model_fn(
            train_compute_ret['binary_seg_logits'], train_binary_labels)
        train_binary_seg_ret_for_summary = evaluate_model_utils.get_image_summary(
            img=train_prediction)
        train_embedding_ret_for_summary = evaluate_model_utils.get_image_summary(
            img=train_pix_embedding)

        train_cost_scalar = tf.summary.scalar(name='train_cost', tensor=train_total_loss)
        train_accuracy_scalar = tf.summary.scalar(name='train_accuracy', tensor=train_accuracy)
        train_binary_seg_loss_scalar = tf.summary.scalar(
            name='train_binary_seg_loss', tensor=train_binary_seg_loss)
        train_instance_seg_loss_scalar = tf.summary.scalar(
            name='train_instance_seg_loss', tensor=train_disc_loss)
        train_fn_scalar = tf.summary.scalar(name='train_fn', tensor=train_fn)
        train_fp_scalar = tf.summary.scalar(name='train_fp', tensor=train_fp)
        train_binary_seg_ret_img = tf.summary.image(
            name='train_binary_seg_ret', tensor=train_binary_seg_ret_for_summary)
        train_embedding_feats_ret_img = tf.summary.image(
            name='train_embedding_feats_ret', tensor=train_embedding_ret_for_summary)
        train_merge_summary_op = tf.summary.merge([
            train_accuracy_scalar, train_cost_scalar, train_binary_seg_loss_scalar,
            train_instance_seg_loss_scalar, train_fn_scalar, train_fp_scalar,
            train_binary_seg_ret_img, train_embedding_feats_ret_img])

        # set compute graph node for validation
        val_images, val_binary_labels, val_instance_labels = val_dataset.inputs(
            CFG.TRAIN.VAL_BATCH_SIZE, 1)

        val_compute_ret = val_net.compute_loss(
            input_tensor=val_images, binary_label=val_binary_labels,
            instance_label=val_instance_labels, name='lanenet_model')

        val_total_loss = val_compute_ret['total_loss']
        val_binary_seg_loss = val_compute_ret['binary_seg_loss']
        val_disc_loss = val_compute_ret['discriminative_loss']
        val_pix_embedding = val_compute_ret['instance_seg_logits']

        val_prediction_logits = val_compute_ret['binary_seg_logits']
        val_prediction_score = tf.nn.softmax(logits=val_prediction_logits)
        val_prediction = tf.argmax(val_prediction_score, axis=-1)

        val_accuracy = evaluate_model_utils.calculate_model_precision(
            val_compute_ret['binary_seg_logits'], val_binary_labels)
        val_fp = evaluate_model_utils.calculate_model_fp(
            val_compute_ret['binary_seg_logits'], val_binary_labels)
        val_fn = evaluate_model_utils.calculate_model_fn(
            val_compute_ret['binary_seg_logits'], val_binary_labels)
        val_binary_seg_ret_for_summary = evaluate_model_utils.get_image_summary(
            img=val_prediction)
        val_embedding_ret_for_summary = evaluate_model_utils.get_image_summary(
            img=val_pix_embedding)

        val_cost_scalar = tf.summary.scalar(name='val_cost', tensor=val_total_loss)
        val_accuracy_scalar = tf.summary.scalar(name='val_accuracy', tensor=val_accuracy)
        val_binary_seg_loss_scalar = tf.summary.scalar(
            name='val_binary_seg_loss', tensor=val_binary_seg_loss)
        val_instance_seg_loss_scalar = tf.summary.scalar(
            name='val_instance_seg_loss', tensor=val_disc_loss)
        val_fn_scalar = tf.summary.scalar(name='val_fn', tensor=val_fn)
        val_fp_scalar = tf.summary.scalar(name='val_fp', tensor=val_fp)
        val_binary_seg_ret_img = tf.summary.image(
            name='val_binary_seg_ret', tensor=val_binary_seg_ret_for_summary)
        val_embedding_feats_ret_img = tf.summary.image(
            name='val_embedding_feats_ret', tensor=val_embedding_ret_for_summary)
        val_merge_summary_op = tf.summary.merge([
            val_accuracy_scalar, val_cost_scalar, val_binary_seg_loss_scalar,
            val_instance_seg_loss_scalar, val_fn_scalar, val_fp_scalar,
            val_binary_seg_ret_img, val_embedding_feats_ret_img])

        # set optimizer
        global_step = tf.Variable(0, trainable=False)
        learning_rate = tf.train.polynomial_decay(
            learning_rate=CFG.TRAIN.LEARNING_RATE,
            global_step=global_step,
            decay_steps=CFG.TRAIN.EPOCHS,
            power=0.9)

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            optimizer = tf.train.MomentumOptimizer(
                learning_rate=learning_rate,
                momentum=CFG.TRAIN.MOMENTUM).minimize(
                    loss=train_total_loss,
                    var_list=tf.trainable_variables(),
                    global_step=global_step)

    # Set tf model save path
    model_save_dir = 'model/tusimple_lanenet_{:s}'.format(net_flag)
    if not ops.exists(model_save_dir):
        os.makedirs(model_save_dir)
    train_start_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
    model_name = 'tusimple_lanenet_{:s}_{:s}.ckpt'.format(net_flag, str(train_start_time))
    model_save_path = ops.join(model_save_dir, model_name)
    saver = tf.train.Saver()

    # Set tf summary save path
    tboard_save_path = 'tboard/tusimple_lanenet_{:s}'.format(net_flag)
    if not ops.exists(tboard_save_path):
        os.makedirs(tboard_save_path)

    # Set sess configuration
    sess_config = tf.ConfigProto(allow_soft_placement=True)
    sess_config.gpu_options.per_process_gpu_memory_fraction = CFG.TRAIN.GPU_MEMORY_FRACTION
    sess_config.gpu_options.allow_growth = CFG.TRAIN.TF_ALLOW_GROWTH
    sess_config.gpu_options.allocator_type = 'BFC'

    sess = tf.Session(config=sess_config)

    summary_writer = tf.summary.FileWriter(tboard_save_path)
    summary_writer.add_graph(sess.graph)

    # Set the training parameters
    train_epochs = CFG.TRAIN.EPOCHS

    log.info('Global configuration is as follows:')
    log.info(CFG)

    with sess.as_default():
        if weights_path is None:
            log.info('Training from scratch')
            init = tf.global_variables_initializer()
            sess.run(init)
        else:
            log.info('Restore model from last model checkpoint {:s}'.format(weights_path))
            saver.restore(sess=sess, save_path=weights_path)

        if net_flag == 'vgg' and weights_path is None:
            load_pretrained_weights(tf.trainable_variables(), './data/vgg16.npy', sess)

        train_cost_time_mean = []
        for epoch in range(train_epochs):
            # training part: each iteration runs a single batch
            t_start = time.time()

            _, train_c, train_accuracy_figure, train_fn_figure, train_fp_figure, \
                lr, train_summary, train_binary_loss, \
                train_instance_loss, train_embeddings, train_binary_seg_imgs, train_gt_imgs, \
                train_binary_gt_labels, train_instance_gt_labels = \
                sess.run([optimizer, train_total_loss, train_accuracy, train_fn,
                          train_fp, learning_rate, train_merge_summary_op,
                          train_binary_seg_loss, train_disc_loss, train_pix_embedding,
                          train_prediction, train_images, train_binary_labels,
                          train_instance_labels])

            if math.isnan(train_c) or math.isnan(train_binary_loss) \
                    or math.isnan(train_instance_loss):
                log.error('cost is: {:.5f}'.format(train_c))
                log.error('binary cost is: {:.5f}'.format(train_binary_loss))
                log.error('instance cost is: {:.5f}'.format(train_instance_loss))
                return

            if epoch % 100 == 0:
                record_training_intermediate_result(
                    gt_images=train_gt_imgs, gt_binary_labels=train_binary_gt_labels,
                    gt_instance_labels=train_instance_gt_labels,
                    binary_seg_images=train_binary_seg_imgs,
                    pix_embeddings=train_embeddings)
            summary_writer.add_summary(summary=train_summary, global_step=epoch)

            if epoch % CFG.TRAIN.DISPLAY_STEP == 0:
                log.info(
                    'Epoch: {:d} total_loss= {:6f} binary_seg_loss= {:6f} '
                    'instance_seg_loss= {:6f} accuracy= {:6f} fp= {:6f} fn= {:6f}'
                    ' lr= {:6f} mean_cost_time= {:5f}s '.format(
                        epoch + 1, train_c, train_binary_loss, train_instance_loss,
                        train_accuracy_figure, train_fp_figure, train_fn_figure,
                        lr, np.mean(train_cost_time_mean)))
                del train_cost_time_mean[:]

            # validation part
            val_c, val_accuracy_figure, val_fn_figure, val_fp_figure, \
                val_summary, val_binary_loss, \
                val_instance_loss, val_embeddings, val_binary_seg_imgs, val_gt_imgs, \
                val_binary_gt_labels, val_instance_gt_labels = \
                sess.run([val_total_loss, val_accuracy, val_fn, val_fp,
                          val_merge_summary_op, val_binary_seg_loss, val_disc_loss,
                          val_pix_embedding, val_prediction, val_images,
                          val_binary_labels, val_instance_labels])

            if math.isnan(val_c) or math.isnan(val_binary_loss) \
                    or math.isnan(val_instance_loss):
                log.error('cost is: {:.5f}'.format(val_c))
                log.error('binary cost is: {:.5f}'.format(val_binary_loss))
                log.error('instance cost is: {:.5f}'.format(val_instance_loss))
                return

            if epoch % 100 == 0:
                record_training_intermediate_result(
                    gt_images=val_gt_imgs, gt_binary_labels=val_binary_gt_labels,
                    gt_instance_labels=val_instance_gt_labels,
                    binary_seg_images=val_binary_seg_imgs,
                    pix_embeddings=val_embeddings, flag='val')

            cost_time = time.time() - t_start
            train_cost_time_mean.append(cost_time)
            summary_writer.add_summary(summary=val_summary, global_step=epoch)

            if epoch % CFG.TRAIN.VAL_DISPLAY_STEP == 0:
                log.info(
                    'Epoch_Val: {:d} total_loss= {:6f} binary_seg_loss= {:6f} '
                    'instance_seg_loss= {:6f} accuracy= {:6f} fp= {:6f} fn= {:6f}'
                    ' mean_cost_time= {:5f}s '.format(
                        epoch + 1, val_c, val_binary_loss, val_instance_loss,
                        val_accuracy_figure, val_fp_figure, val_fn_figure,
                        np.mean(train_cost_time_mean)))
                del train_cost_time_mean[:]

            if epoch % 2000 == 0:
                saver.save(sess=sess, save_path=model_save_path, global_step=global_step)

    return
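# For reference, below is a rough NumPy sketch of what the evaluate_model_utils
# helpers used above plausibly compute. The exact definitions live in that
# module; the shapes and formulas here are assumptions, not the repo's code.
import numpy as np


def binary_seg_metrics(binary_seg_logits, binary_labels):
    """Hypothetical stand-in for calculate_model_precision / _fp / _fn.

    binary_seg_logits: [N, H, W, 2] logits; binary_labels: [N, H, W, 1] in {0, 1}.
    """
    pred = np.argmax(binary_seg_logits, axis=-1)               # [N, H, W] binary map
    gt = np.squeeze(binary_labels, axis=-1).astype(np.int64)   # [N, H, W]
    tp = np.sum((pred == 1) & (gt == 1))   # lane pixels predicted as lane
    fp = np.sum((pred == 1) & (gt == 0))   # background predicted as lane
    fn = np.sum((pred == 0) & (gt == 1))   # lane pixels missed
    precision = tp / max(tp + fp, 1)       # the 'accuracy' figure: TP / predicted lane
    fp_rate = fp / max(tp + fp, 1)         # false positives among predicted lane
    fn_rate = fn / max(tp + fn, 1)         # misses among ground-truth lane
    return precision, fp_rate, fn_rate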
def train_lanenet(weights_path=None, net_flag='vgg', version_flag='', scratch=False):
    """
    :param weights_path: optional checkpoint to restore before training
    :param net_flag: choose which base network to use
    :param version_flag: experiment flag
    :param scratch: if True, restore weights_path but drop the Momentum and global_step variables
    :return:
    """
    # ========================== placeholder ========================= #
    with tf.name_scope('train_input'):
        train_input_tensor = tf.placeholder(dtype=tf.float32, name='input_image',
                                            shape=[None, None, None, 3])
        train_binary_label_tensor = tf.placeholder(dtype=tf.float32,
                                                   name='binary_input_label',
                                                   shape=[None, None, None, 1])
        train_instance_label_tensor = tf.placeholder(dtype=tf.float32,
                                                     name='instance_input_label',
                                                     shape=[None, None, None, 1])
    with tf.name_scope('val_input'):
        val_input_tensor = tf.placeholder(dtype=tf.float32, name='input_image',
                                          shape=[None, None, None, 3])
        val_binary_label_tensor = tf.placeholder(dtype=tf.float32,
                                                 name='binary_input_label',
                                                 shape=[None, None, None, 1])
        val_instance_label_tensor = tf.placeholder(dtype=tf.float32,
                                                   name='instance_input_label',
                                                   shape=[None, None, None, 1])

    # ================================================================ #
    #                          Define Network                          #
    # ================================================================ #
    train_net = lanenet.LaneNet(net_flag=net_flag, phase='train', reuse=tf.AUTO_REUSE)
    val_net = lanenet.LaneNet(net_flag=net_flag, phase='val', reuse=True)
    # ---------------------------------------------------------------- #

    # ================================================================ #
    #                       Train Input & Output                       #
    # ================================================================ #
    trainset = DataSet('train')
    # trainset = MergeDataSet('train_lane')
    train_compute_ret = train_net.compute_loss(
        input_tensor=train_input_tensor, binary_label=train_binary_label_tensor,
        instance_label=train_instance_label_tensor, name='lanenet_model')

    train_total_loss = train_compute_ret['total_loss']
    train_binary_seg_loss = train_compute_ret['binary_seg_loss']      # semantic segmentation loss
    train_disc_loss = train_compute_ret['discriminative_loss']        # embedding loss
    train_pix_embedding = train_compute_ret['instance_seg_logits']    # embedding feature, HxWxN
    train_l2_reg_loss = train_compute_ret['l2_reg_loss']

    train_prediction_logits = train_compute_ret['binary_seg_logits']  # semantic segmentation result, HxWx2
    train_prediction_score = tf.nn.softmax(logits=train_prediction_logits)
    train_prediction = tf.argmax(train_prediction_score, axis=-1)     # binary segmentation map

    train_accuracy = evaluate_model_utils.calculate_model_precision(
        train_compute_ret['binary_seg_logits'], train_binary_label_tensor)
    train_fp = evaluate_model_utils.calculate_model_fp(
        train_compute_ret['binary_seg_logits'], train_binary_label_tensor)
    train_fn = evaluate_model_utils.calculate_model_fn(
        train_compute_ret['binary_seg_logits'], train_binary_label_tensor)
    train_binary_seg_ret_for_summary = evaluate_model_utils.get_image_summary(
        img=train_prediction)     # (I - min) * 255 / (max - min), normalized to 0-255
    train_embedding_ret_for_summary = evaluate_model_utils.get_image_summary(
        img=train_pix_embedding)  # (I - min) * 255 / (max - min), normalized to 0-255
    # ---------------------------------------------------------------- #

    # ================================================================ #
    #                         Define Optimizer                         #
    # ================================================================ #
    # set optimizer
    global_step = tf.Variable(0, trainable=False, name='global_step')
    # learning_rate = tf.train.cosine_decay_restarts(   # cosine decay with restarts
    #     learning_rate=cfg.TRAIN.LEARNING_RATE,        # initial learning rate
    #     global_step=global_step,                      # current step
    #     first_decay_steps=cfg.TRAIN.STEPS/3,          # length of the first decay period
    #     t_mul=2.0,                                    # period multiplier for each restart
    #     m_mul=1.0,                                    # learning-rate multiplier for each restart
    #     alpha=0.1,                                    # minimum learning rate = alpha * learning_rate
    # )
    learning_rate = tf.train.polynomial_decay(          # polynomial decay
        learning_rate=cfg.TRAIN.LEARNING_RATE,          # initial learning rate
        global_step=global_step,                        # current step
        decay_steps=cfg.TRAIN.STEPS / 4,                # steps per decay cycle; the rate reaches end_learning_rate here
        end_learning_rate=cfg.TRAIN.LEARNING_RATE / 10, # minimum learning rate
        power=0.9,
        cycle=True)
    learning_rate_scalar = tf.summary.scalar(name='learning_rate', tensor=learning_rate)

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)  # for batch normalization
    with tf.control_dependencies(update_ops):
        optimizer = tf.train.MomentumOptimizer(
            learning_rate=learning_rate, momentum=cfg.TRAIN.MOMENTUM).minimize(
                loss=train_total_loss,
                var_list=tf.trainable_variables(),
                global_step=global_step)
    # ---------------------------------------------------------------- #

    # ================================================================ #
    #                           Train Summary                          #
    # ================================================================ #
    train_loss_scalar = tf.summary.scalar(name='train_cost', tensor=train_total_loss)
    train_accuracy_scalar = tf.summary.scalar(name='train_accuracy', tensor=train_accuracy)
    train_binary_seg_loss_scalar = tf.summary.scalar(
        name='train_binary_seg_loss', tensor=train_binary_seg_loss)
    train_instance_seg_loss_scalar = tf.summary.scalar(
        name='train_instance_seg_loss', tensor=train_disc_loss)
    train_fn_scalar = tf.summary.scalar(name='train_fn', tensor=train_fn)
    train_fp_scalar = tf.summary.scalar(name='train_fp', tensor=train_fp)
    train_binary_seg_ret_img = tf.summary.image(
        name='train_binary_seg_ret', tensor=train_binary_seg_ret_for_summary)
    train_embedding_feats_ret_img = tf.summary.image(
        name='train_embedding_feats_ret', tensor=train_embedding_ret_for_summary)
    train_merge_summary_op = tf.summary.merge([
        train_accuracy_scalar, train_loss_scalar, train_binary_seg_loss_scalar,
        train_instance_seg_loss_scalar, train_fn_scalar, train_fp_scalar,
        train_binary_seg_ret_img, train_embedding_feats_ret_img,
        learning_rate_scalar])
    # ---------------------------------------------------------------- #

    # ================================================================ #
    #                        Val Input & Output                        #
    # ================================================================ #
    valset = DataSet('val', net_flag)
    # valset = MergeDataSet('test_lane')
    val_compute_ret = val_net.compute_loss(
        input_tensor=val_input_tensor, binary_label=val_binary_label_tensor,
        instance_label=val_instance_label_tensor, name='lanenet_model')

    val_total_loss = val_compute_ret['total_loss']
    val_binary_seg_loss = val_compute_ret['binary_seg_loss']
    val_disc_loss = val_compute_ret['discriminative_loss']
    val_pix_embedding = val_compute_ret['instance_seg_logits']

    val_prediction_logits = val_compute_ret['binary_seg_logits']
    val_prediction_score = tf.nn.softmax(logits=val_prediction_logits)
    val_prediction = tf.argmax(val_prediction_score, axis=-1)

    val_accuracy = evaluate_model_utils.calculate_model_precision(
        val_compute_ret['binary_seg_logits'], val_binary_label_tensor)
    val_fp = evaluate_model_utils.calculate_model_fp(
        val_compute_ret['binary_seg_logits'], val_binary_label_tensor)
    val_fn = evaluate_model_utils.calculate_model_fn(
        val_compute_ret['binary_seg_logits'], val_binary_label_tensor)
    val_binary_seg_ret_for_summary = evaluate_model_utils.get_image_summary(
        img=val_prediction)
    val_embedding_ret_for_summary = evaluate_model_utils.get_image_summary(
        img=val_pix_embedding)
    # ---------------------------------------------------------------- #

    # ================================================================ #
    #                            VAL Summary                           #
    # ================================================================ #
    val_loss_scalar = tf.summary.scalar(name='val_cost', tensor=val_total_loss)
    val_accuracy_scalar = tf.summary.scalar(name='val_accuracy', tensor=val_accuracy)
    val_binary_seg_loss_scalar = tf.summary.scalar(name='val_binary_seg_loss',
                                                   tensor=val_binary_seg_loss)
    val_instance_seg_loss_scalar = tf.summary.scalar(
        name='val_instance_seg_loss', tensor=val_disc_loss)
    val_fn_scalar = tf.summary.scalar(name='val_fn', tensor=val_fn)
    val_fp_scalar = tf.summary.scalar(name='val_fp', tensor=val_fp)
    val_binary_seg_ret_img = tf.summary.image(
        name='val_binary_seg_ret', tensor=val_binary_seg_ret_for_summary)
    val_embedding_feats_ret_img = tf.summary.image(
        name='val_embedding_feats_ret', tensor=val_embedding_ret_for_summary)
    val_merge_summary_op = tf.summary.merge([
        val_accuracy_scalar, val_loss_scalar, val_binary_seg_loss_scalar,
        val_instance_seg_loss_scalar, val_fn_scalar, val_fp_scalar,
        val_binary_seg_ret_img, val_embedding_feats_ret_img])
    # ---------------------------------------------------------------- #

    # ================================================================ #
    #                      Config Saver & Session                      #
    # ================================================================ #
    # Set tf model save path
    model_save_dir = 'model/tusimple_lanenet_{:s}_{:s}'.format(net_flag, version_flag)
    os.makedirs(model_save_dir, exist_ok=True)
    train_start_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
    model_name = 'tusimple_lanenet_{:s}_{:s}.ckpt'.format(net_flag, str(train_start_time))
    model_save_path = ops.join(model_save_dir, model_name)

    # ==============================
    if scratch:
        """
        Drop the Momentum slot variables; note that the meta file saved here
        loses them as well. When TensorFlow saves a model with the global_step
        option, the global_step value is stored in the checkpoint too, so a
        plain restore would resume training from that step; it has to be
        removed to restart the step counter.
        """
        variables = tf.contrib.framework.get_variables_to_restore()
        variables_to_restore = [
            v for v in variables if 'Momentum' not in v.name.split('/')[-1]
        ]
        variables_to_restore = [
            v for v in variables_to_restore
            if 'global_step' not in v.name.split('/')[-1]
        ]  # remove global step
        restore_saver = tf.train.Saver(variables_to_restore)
    else:
        restore_saver = tf.train.Saver()
    saver = tf.train.Saver(max_to_keep=10)
    # ==============================

    # Set tf summary save path
    tboard_save_path = 'tboard/tusimple_lanenet_{:s}_{:s}'.format(net_flag, version_flag)
    os.makedirs(tboard_save_path, exist_ok=True)

    # Set sess configuration
    # ============================== config GPU
    sess_config = tf.ConfigProto(allow_soft_placement=True)
    # sess_config.gpu_options.per_process_gpu_memory_fraction = cfg.TRAIN.GPU_MEMORY_FRACTION
    sess_config.gpu_options.allow_growth = cfg.TRAIN.TF_ALLOW_GROWTH
    sess_config.gpu_options.allocator_type = 'BFC'
    # ==============================
    sess = tf.Session(config=sess_config)

    summary_writer = tf.summary.FileWriter(tboard_save_path)
    summary_writer.add_graph(sess.graph)
    # ---------------------------------------------------------------- #

    # Set the training parameters
    import math
    one_epoch2step = math.ceil(cfg.TRAIN.TRAIN_SIZE / cfg.TRAIN.BATCH_SIZE)  # batches per training epoch
    total_epoch = math.ceil(cfg.TRAIN.STEPS / one_epoch2step)  # total number of epochs to train

    log.info('Global configuration is as follows:')
    log.info(cfg)
    max_acc = 0.9
    save_num = 0
    val_step = 0

    # ================================================================ #
    #                            Train & Val                           #
    # ================================================================ #
    with sess.as_default():
        # ============================== load pretrained model
        # if weights_path is None:
        #     log.info('Training from scratch')
        #     sess.run(tf.global_variables_initializer())
        # elif net_flag == 'vgg' and weights_path is None:
        #     load_pretrained_weights(tf.trainable_variables(), './data/vgg16.npy', sess)
        # elif scratch:  # train from the given weights only, like Caffe's --weights
        #     sess.run(tf.global_variables_initializer())
        #     log.info('Restore model from last model checkpoint {:s}, scratch'.format(weights_path))
        #     try:
        #         restore_saver.restore(sess=sess, save_path=weights_path)
        #     except Exception:
        #         log.info('model may not exist!')
        # else:  # resume training, like Caffe's --snapshot
        #     log.info('Restore model from last model checkpoint {:s}'.format(weights_path))
        #     try:
        #         restore_saver.restore(sess=sess, save_path=weights_path)
        #     except Exception:
        #         log.info('model may not exist!')
        sess.run(tf.global_variables_initializer())
        # ==============================

        for epoch in range(total_epoch):
            # ================================================================ #
            #                               Train                              #
            # ================================================================ #
            train_epoch_loss = []
            pbar_train = tqdm(trainset)
            train_t_start = time.time()
            for gt_imgs, binary_gt_labels, instance_gt_labels in pbar_train:
                _, global_step_val, train_loss, train_accuracy_figure, \
                    train_fn_figure, train_fp_figure, \
                    lr, train_summary, train_binary_loss, train_instance_loss, \
                    train_embeddings, train_binary_seg_imgs, train_l2_loss = \
                    sess.run(
                        [optimizer, global_step, train_total_loss, train_accuracy,
                         train_fn, train_fp, learning_rate, train_merge_summary_op,
                         train_binary_seg_loss, train_disc_loss, train_pix_embedding,
                         train_prediction, train_l2_reg_loss],
                        feed_dict={train_input_tensor: gt_imgs,
                                   train_binary_label_tensor: binary_gt_labels,
                                   train_instance_label_tensor: instance_gt_labels})

                # ============================== bail out as soon as a loss goes NaN
                if math.isnan(train_loss) or math.isnan(train_binary_loss) \
                        or math.isnan(train_instance_loss):
                    log.error('cost is: {:.5f}'.format(train_loss))
                    log.error('binary cost is: {:.5f}'.format(train_binary_loss))
                    log.error('instance cost is: {:.5f}'.format(train_instance_loss))
                    return
                # ==============================

                train_epoch_loss.append(train_loss)
                summary_writer.add_summary(summary=train_summary,
                                           global_step=global_step_val)
                pbar_train.set_description(
                    'train loss: %.4f, learn rate: %e' % (train_loss, lr))

            train_cost_time = time.time() - train_t_start
            mean_train_loss = np.mean(train_epoch_loss)
            log.info('MEAN Train: total_loss= {:6f} mean_cost_time= {:5f}s'.format(
                mean_train_loss, train_cost_time))
            # ---------------------------------------------------------------- #

            # ================================================================ #
            #                                Val                               #
            # ================================================================ #
            # run the whole validation set once per epoch
            pbar_val = tqdm(valset)
            val_epoch_loss = []
            val_epoch_binary_loss = []
            val_epoch_instance_loss = []
            val_epoch_accuracy_figure = []
            val_epoch_fp_figure = []
            val_epoch_fn_figure = []
            val_t_start = time.time()
            for val_images, val_binary_labels, val_instance_labels in pbar_val:
                # validation part
                val_step += 1
                val_summary, \
                    val_loss, val_binary_loss, val_instance_loss, \
                    val_accuracy_figure, val_fn_figure, val_fp_figure = \
                    sess.run(
                        [val_merge_summary_op, val_total_loss, val_binary_seg_loss,
                         val_disc_loss, val_accuracy, val_fn, val_fp],
                        feed_dict={val_input_tensor: val_images,
                                   val_binary_label_tensor: val_binary_labels,
                                   val_instance_label_tensor: val_instance_labels})

                # ============================== bail out as soon as a loss goes NaN
                if math.isnan(val_loss) or math.isnan(val_binary_loss) \
                        or math.isnan(val_instance_loss):
                    log.error('cost is: {:.5f}'.format(val_loss))
                    log.error('binary cost is: {:.5f}'.format(val_binary_loss))
                    log.error('instance cost is: {:.5f}'.format(val_instance_loss))
                    return
                # ==============================

                summary_writer.add_summary(summary=val_summary, global_step=val_step)
                pbar_val.set_description('val loss: %.4f, accuracy: %.4f'
                                         % (val_loss, val_accuracy_figure))

                val_epoch_loss.append(val_loss)
                val_epoch_binary_loss.append(val_binary_loss)
                val_epoch_instance_loss.append(val_instance_loss)
                val_epoch_accuracy_figure.append(val_accuracy_figure)
                val_epoch_fp_figure.append(val_fp_figure)
                val_epoch_fn_figure.append(val_fn_figure)

            val_cost_time = time.time() - val_t_start
            mean_val_loss = np.mean(val_epoch_loss)
            mean_val_binary_loss = np.mean(val_epoch_binary_loss)
            mean_val_instance_loss = np.mean(val_epoch_instance_loss)
            mean_val_accuracy_figure = np.mean(val_epoch_accuracy_figure)
            mean_val_fp_figure = np.mean(val_epoch_fp_figure)
            mean_val_fn_figure = np.mean(val_epoch_fn_figure)

            # ==============================
            if mean_val_accuracy_figure > max_acc:
                max_acc = mean_val_accuracy_figure
                if save_num < 3:  # the first three saves do not raise the bar
                    max_acc = 0.9
                log.info('MAX_ACC change to {}'.format(mean_val_accuracy_figure))
                model_save_path_max = ops.join(
                    model_save_dir,
                    'tusimple_lanenet_{}.ckpt'.format(mean_val_accuracy_figure))
                saver.save(sess=sess, save_path=model_save_path_max,
                           global_step=global_step)
                save_num += 1
            # ==============================

            log.info(
                '=> Epoch: {}, MEAN Val: total_loss= {:6f} binary_seg_loss= {:6f} '
                'instance_seg_loss= {:6f} accuracy= {:6f} fp= {:6f} fn= {:6f}'
                ' mean_cost_time= {:5f}s '.format(
                    epoch, mean_val_loss, mean_val_binary_loss,
                    mean_val_instance_loss, mean_val_accuracy_figure,
                    mean_val_fp_figure, mean_val_fn_figure, val_cost_time))
            # ---------------------------------------------------------------- #
    return
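# The cyclic polynomial decay configured above can be hard to picture. Below is
# a minimal pure-Python sketch of the schedule tf.train.polynomial_decay
# implements when cycle=True, following its documented formula; the constants
# in the usage comment are illustrative, not values taken from cfg.
import math


def cyclic_polynomial_decay(lr0, step, decay_steps, end_lr, power=0.9):
    # With cycle=True, decay_steps is stretched to the next multiple covering
    # the current step, so the rate decays to end_lr and then jumps back up at
    # each cycle boundary instead of staying flat at end_lr.
    decay_steps = decay_steps * max(1.0, math.ceil(step / decay_steps))
    frac = step / decay_steps
    return (lr0 - end_lr) * (1.0 - frac) ** power + end_lr

# e.g. with lr0=0.01, STEPS=80000, so decay_steps=20000 and end_lr=0.001:
#   step 0      -> 0.01000
#   step 20000  -> 0.00100   (end of the first cycle)
#   step 20001  -> ~0.00582  (cycle restarts with a doubled period)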
def train_lanenet(dataset_dir, weights_path=None, net_flag='vgg',
                  version_flag='', scratch=False):
    """
    Train LaneNet with one GPU.

    :param dataset_dir: directory holding the training/validation data
    :param weights_path: optional checkpoint to restore before training
    :param net_flag: choose which base network to use
    :param version_flag: experiment flag
    :param scratch: if True, restore weights_path but drop the Momentum and global_step variables
    :return:
    """
    train_dataset = lanenet_data_feed_pipline.LaneNetDataFeeder(
        dataset_dir=dataset_dir, flags='train')
    val_dataset = lanenet_data_feed_pipline.LaneNetDataFeeder(
        dataset_dir=dataset_dir, flags='val')

    # ================================================================ #
    #                          Define Network                          #
    # ================================================================ #
    train_net = lanenet.LaneNet(net_flag=net_flag, phase='train', reuse=tf.AUTO_REUSE)
    val_net = lanenet.LaneNet(net_flag=net_flag, phase='val', reuse=True)
    # ---------------------------------------------------------------- #

    # ================================================================ #
    #                       Train Input & Output                       #
    # ================================================================ #
    # set compute graph node for training
    train_images, train_binary_labels, train_instance_labels = train_dataset.inputs(
        CFG.TRAIN.BATCH_SIZE)

    train_compute_ret = train_net.compute_loss(
        input_tensor=train_images, binary_label=train_binary_labels,
        instance_label=train_instance_labels, name='lanenet_model')

    train_total_loss = train_compute_ret['total_loss']
    train_binary_seg_loss = train_compute_ret['binary_seg_loss']      # semantic segmentation loss
    train_disc_loss = train_compute_ret['discriminative_loss']        # embedding loss
    train_pix_embedding = train_compute_ret['instance_seg_logits']    # embedding feature, HxWxN
    train_l2_reg_loss = train_compute_ret['l2_reg_loss']

    train_prediction_logits = train_compute_ret['binary_seg_logits']  # semantic segmentation result, HxWx2
    train_prediction_score = tf.nn.softmax(logits=train_prediction_logits)
    train_prediction = tf.argmax(train_prediction_score, axis=-1)     # binary segmentation map

    train_accuracy = evaluate_model_utils.calculate_model_precision(
        train_compute_ret['binary_seg_logits'], train_binary_labels)
    train_fp = evaluate_model_utils.calculate_model_fp(
        train_compute_ret['binary_seg_logits'], train_binary_labels)
    train_fn = evaluate_model_utils.calculate_model_fn(
        train_compute_ret['binary_seg_logits'], train_binary_labels)
    train_binary_seg_ret_for_summary = evaluate_model_utils.get_image_summary(
        img=train_prediction)     # (I - min) * 255 / (max - min), normalized to 0-255
    train_embedding_ret_for_summary = evaluate_model_utils.get_image_summary(
        img=train_pix_embedding)  # (I - min) * 255 / (max - min), normalized to 0-255
    # ---------------------------------------------------------------- #

    # ================================================================ #
    #                         Define Optimizer                         #
    # ================================================================ #
    # set optimizer
    global_step = tf.Variable(0, trainable=False, name='global_step')
    # learning_rate = tf.train.cosine_decay_restarts(   # cosine decay with restarts
    #     learning_rate=CFG.TRAIN.LEARNING_RATE,        # initial learning rate
    #     global_step=global_step,                      # current step
    #     first_decay_steps=CFG.TRAIN.STEPS/3,          # length of the first decay period
    #     t_mul=2.0,                                    # period multiplier for each restart
    #     m_mul=1.0,                                    # learning-rate multiplier for each restart
    #     alpha=0.1,                                    # minimum learning rate = alpha * learning_rate
    # )
    learning_rate = tf.train.polynomial_decay(          # polynomial decay
        learning_rate=CFG.TRAIN.LEARNING_RATE,          # initial learning rate
        global_step=global_step,                        # current step
        decay_steps=CFG.TRAIN.STEPS / 4,                # steps per decay cycle; the rate reaches end_learning_rate here
        end_learning_rate=CFG.TRAIN.LEARNING_RATE / 10, # minimum learning rate
        power=0.9,
        cycle=True)
    learning_rate_scalar = tf.summary.scalar(name='learning_rate', tensor=learning_rate)

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)  # for batch normalization
    with tf.control_dependencies(update_ops):
        optimizer = tf.train.MomentumOptimizer(
            learning_rate=learning_rate, momentum=CFG.TRAIN.MOMENTUM).minimize(
                loss=train_total_loss,
                var_list=tf.trainable_variables(),
                global_step=global_step)
    # ---------------------------------------------------------------- #

    # ================================================================ #
    #                           Train Summary                          #
    # ================================================================ #
    train_cost_scalar = tf.summary.scalar(name='train_cost', tensor=train_total_loss)
    train_accuracy_scalar = tf.summary.scalar(name='train_accuracy', tensor=train_accuracy)
    train_binary_seg_loss_scalar = tf.summary.scalar(
        name='train_binary_seg_loss', tensor=train_binary_seg_loss)
    train_instance_seg_loss_scalar = tf.summary.scalar(
        name='train_instance_seg_loss', tensor=train_disc_loss)
    train_fn_scalar = tf.summary.scalar(name='train_fn', tensor=train_fn)
    train_fp_scalar = tf.summary.scalar(name='train_fp', tensor=train_fp)
    train_binary_seg_ret_img = tf.summary.image(
        name='train_binary_seg_ret', tensor=train_binary_seg_ret_for_summary)
    train_embedding_feats_ret_img = tf.summary.image(
        name='train_embedding_feats_ret', tensor=train_embedding_ret_for_summary)
    train_merge_summary_op = tf.summary.merge([
        train_accuracy_scalar, train_cost_scalar, train_binary_seg_loss_scalar,
        train_instance_seg_loss_scalar, train_fn_scalar, train_fp_scalar,
        train_binary_seg_ret_img, train_embedding_feats_ret_img,
        learning_rate_scalar])
    # ---------------------------------------------------------------- #

    # ================================================================ #
    #                        Val Input & Output                        #
    # ================================================================ #
    # set compute graph node for validation
    val_images, val_binary_labels, val_instance_labels = val_dataset.inputs(
        CFG.TEST.BATCH_SIZE)

    val_compute_ret = val_net.compute_loss(
        input_tensor=val_images, binary_label=val_binary_labels,
        instance_label=val_instance_labels, name='lanenet_model')

    val_total_loss = val_compute_ret['total_loss']
    val_binary_seg_loss = val_compute_ret['binary_seg_loss']
    val_disc_loss = val_compute_ret['discriminative_loss']
    val_pix_embedding = val_compute_ret['instance_seg_logits']

    val_prediction_logits = val_compute_ret['binary_seg_logits']
    val_prediction_score = tf.nn.softmax(logits=val_prediction_logits)
    val_prediction = tf.argmax(val_prediction_score, axis=-1)

    val_accuracy = evaluate_model_utils.calculate_model_precision(
        val_compute_ret['binary_seg_logits'], val_binary_labels)
    val_fp = evaluate_model_utils.calculate_model_fp(
        val_compute_ret['binary_seg_logits'], val_binary_labels)
    val_fn = evaluate_model_utils.calculate_model_fn(
        val_compute_ret['binary_seg_logits'], val_binary_labels)
    val_binary_seg_ret_for_summary = evaluate_model_utils.get_image_summary(
        img=val_prediction)
    val_embedding_ret_for_summary = evaluate_model_utils.get_image_summary(
        img=val_pix_embedding)
    # ---------------------------------------------------------------- #

    # ================================================================ #
    #                            VAL Summary                           #
    # ================================================================ #
    val_cost_scalar = tf.summary.scalar(name='val_cost', tensor=val_total_loss)
    val_accuracy_scalar = tf.summary.scalar(name='val_accuracy', tensor=val_accuracy)
    val_binary_seg_loss_scalar = tf.summary.scalar(name='val_binary_seg_loss',
                                                   tensor=val_binary_seg_loss)
    val_instance_seg_loss_scalar = tf.summary.scalar(
        name='val_instance_seg_loss', tensor=val_disc_loss)
    val_fn_scalar = tf.summary.scalar(name='val_fn', tensor=val_fn)
    val_fp_scalar = tf.summary.scalar(name='val_fp', tensor=val_fp)
    val_binary_seg_ret_img = tf.summary.image(
        name='val_binary_seg_ret', tensor=val_binary_seg_ret_for_summary)
    val_embedding_feats_ret_img = tf.summary.image(
        name='val_embedding_feats_ret', tensor=val_embedding_ret_for_summary)
    val_merge_summary_op = tf.summary.merge([
        val_accuracy_scalar, val_cost_scalar, val_binary_seg_loss_scalar,
        val_instance_seg_loss_scalar, val_fn_scalar, val_fp_scalar,
        val_binary_seg_ret_img, val_embedding_feats_ret_img])
    # ---------------------------------------------------------------- #

    # ================================================================ #
    #                      Config Saver & Session                      #
    # ================================================================ #
    # Set tf model save path
    model_save_dir = 'model/tusimple_lanenet_{:s}_{:s}'.format(net_flag, version_flag)
    os.makedirs(model_save_dir, exist_ok=True)
    train_start_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
    model_name = 'tusimple_lanenet_{:s}_{:s}.ckpt'.format(net_flag, str(train_start_time))
    model_save_path = ops.join(model_save_dir, model_name)

    # ==============================
    if scratch:
        """
        Drop the Momentum slot variables; note that the meta file saved here
        loses them as well. When TensorFlow saves a model with the global_step
        option, the global_step value is stored in the checkpoint too, so a
        plain restore would resume training from that step; it has to be
        removed to restart the step counter.
        """
        variables = tf.contrib.framework.get_variables_to_restore()
        variables_to_restore = [
            v for v in variables if 'Momentum' not in v.name.split('/')[-1]
        ]
        variables_to_restore = [
            v for v in variables_to_restore
            if 'global_step' not in v.name.split('/')[-1]
        ]  # remove global step
        restore_saver = tf.train.Saver(variables_to_restore)
    else:
        restore_saver = tf.train.Saver()
    saver = tf.train.Saver(max_to_keep=10)
    # ==============================

    # Set tf summary save path
    tboard_save_path = 'tboard/tusimple_lanenet_{:s}_{:s}'.format(net_flag, version_flag)
    os.makedirs(tboard_save_path, exist_ok=True)

    # Set sess configuration
    # ============================== config GPU
    sess_config = tf.ConfigProto(allow_soft_placement=True)
    sess_config.gpu_options.per_process_gpu_memory_fraction = CFG.TRAIN.GPU_MEMORY_FRACTION
    sess_config.gpu_options.allow_growth = CFG.TRAIN.TF_ALLOW_GROWTH
    sess_config.gpu_options.allocator_type = 'BFC'
    # ==============================
    sess = tf.Session(config=sess_config)

    summary_writer = tf.summary.FileWriter(tboard_save_path)
    summary_writer.add_graph(sess.graph)
    # ---------------------------------------------------------------- #

    # Set the training parameters
    import math
    train_steps = CFG.TRAIN.STEPS
    val_steps = math.ceil(CFG.TRAIN.VAL_SIZE / CFG.TEST.BATCH_SIZE)  # batches per validation pass
    one_epoch2step = math.ceil(CFG.TRAIN.TRAIN_SIZE / CFG.TRAIN.BATCH_SIZE)  # batches per training epoch

    log.info('Global configuration is as follows:')
    log.info(CFG)
    max_acc = 0.9
    save_num = 0

    # ================================================================ #
    #                            Train & Val                           #
    # ================================================================ #
    with sess.as_default():
        # ============================== load pretrained model
        if weights_path is None:
            log.info('Training from scratch')
            sess.run(tf.global_variables_initializer())
            if net_flag == 'vgg':
                # initialize the backbone with ImageNet-pretrained VGG16 weights
                load_pretrained_weights(tf.trainable_variables(),
                                        './data/vgg16.npy', sess)
        elif scratch:  # train from the given weights only, like Caffe's --weights
            sess.run(tf.global_variables_initializer())
            log.info('Restore model from last model checkpoint {:s}, scratch'.format(
                weights_path))
            try:
                restore_saver.restore(sess=sess, save_path=weights_path)
            except Exception:
                log.info('model may not exist!')
        else:  # resume training, like Caffe's --snapshot
            log.info('Restore model from last model checkpoint {:s}'.format(
                weights_path))
            try:
                restore_saver.restore(sess=sess, save_path=weights_path)
            except Exception:
                log.info('model may not exist!')
        # ==============================

        train_cost_time_mean = []  # tracks per-batch training time
        for step in range(train_steps):
            # ================================================================ #
            #                               Train                              #
            # ================================================================ #
            t_start = time.time()

            _, train_loss, train_accuracy_figure, train_fn_figure, train_fp_figure, \
                lr, train_summary, train_binary_loss, \
                train_instance_loss, train_embeddings, train_binary_seg_imgs, train_gt_imgs, \
                train_binary_gt_labels, train_instance_gt_labels, train_l2_loss = \
                sess.run([optimizer, train_total_loss, train_accuracy, train_fn,
                          train_fp, learning_rate, train_merge_summary_op,
                          train_binary_seg_loss, train_disc_loss, train_pix_embedding,
                          train_prediction, train_images, train_binary_labels,
                          train_instance_labels, train_l2_reg_loss])

            cost_time = time.time() - t_start
            train_cost_time_mean.append(cost_time)

            # ============================== bail out as soon as a loss goes NaN
            if math.isnan(train_loss) or math.isnan(train_binary_loss) \
                    or math.isnan(train_instance_loss):
                log.error('cost is: {:.5f}'.format(train_loss))
                log.error('binary cost is: {:.5f}'.format(train_binary_loss))
                log.error('instance cost is: {:.5f}'.format(train_instance_loss))
                return
            # ==============================

            summary_writer.add_summary(summary=train_summary, global_step=step)

            # print the losses every DISPLAY_STEP steps
            if step % CFG.TRAIN.DISPLAY_STEP == 0:
                epoch_num = step // one_epoch2step
                log.info(
                    'Epoch: {:d} Step: {:d} total_loss= {:6f} binary_seg_loss= {:6f} '
                    'instance_seg_loss= {:6f} l2_reg_loss= {:6f} accuracy= {:6f} fp= {:6f} fn= {:6f}'
                    ' lr= {:6f} mean_cost_time= {:5f}s '.format(
                        epoch_num + 1, step + 1, train_loss, train_binary_loss,
                        train_instance_loss, train_l2_loss, train_accuracy_figure,
                        train_fp_figure, train_fn_figure, lr,
                        np.mean(train_cost_time_mean)))
                train_cost_time_mean.clear()

            # # every VAL_DISPLAY_STEP steps, save the model and the current batch's result images
            # if step % CFG.TRAIN.VAL_DISPLAY_STEP == 0:
            #     saver.save(sess=sess, save_path=model_save_path, global_step=global_step)
            #     # passing global_step stores the step counter in the checkpoint
            #     record_training_intermediate_result(
            #         gt_images=train_gt_imgs, gt_binary_labels=train_binary_gt_labels,
            #         gt_instance_labels=train_instance_gt_labels,
            #         binary_seg_images=train_binary_seg_imgs,
            #         pix_embeddings=train_embeddings
            #     )
            # ---------------------------------------------------------------- #

            # ================================================================ #
            #                                Val                               #
            # ================================================================ #
            # every VAL_DISPLAY_STEP steps, evaluate the whole validation set
            if step % CFG.TRAIN.VAL_DISPLAY_STEP == 0:
                val_t_start = time.time()
                val_cost_time = 0
                mean_val_c = 0.0
                mean_val_binary_loss = 0.0
                mean_val_instance_loss = 0.0
                mean_val_accuracy_figure = 0.0
                mean_val_fp_figure = 0.0
                mean_val_fn_figure = 0.0

                for val_step in range(val_steps):
                    # validation part
                    val_c, val_accuracy_figure, val_fn_figure, val_fp_figure, \
                        val_summary, val_binary_loss, val_instance_loss, \
                        val_embeddings, val_binary_seg_imgs, val_gt_imgs, \
                        val_binary_gt_labels, val_instance_gt_labels = \
                        sess.run([val_total_loss, val_accuracy, val_fn, val_fp,
                                  val_merge_summary_op, val_binary_seg_loss,
                                  val_disc_loss, val_pix_embedding, val_prediction,
                                  val_images, val_binary_labels, val_instance_labels])

                    # ============================== bail out as soon as a loss goes NaN
                    if math.isnan(val_c) or math.isnan(val_binary_loss) \
                            or math.isnan(val_instance_loss):
                        log.error('cost is: {:.5f}'.format(val_c))
                        log.error('binary cost is: {:.5f}'.format(val_binary_loss))
                        log.error('instance cost is: {:.5f}'.format(val_instance_loss))
                        return
                    # ==============================

                    # if val_step == 0:
                    #     record_training_intermediate_result(
                    #         gt_images=val_gt_imgs, gt_binary_labels=val_binary_gt_labels,
                    #         gt_instance_labels=val_instance_gt_labels,
                    #         binary_seg_images=val_binary_seg_imgs,
                    #         pix_embeddings=val_embeddings, flag='val'
                    #     )

                    cost_time = time.time() - val_t_start
                    val_t_start = time.time()  # reset the timer so each batch is counted once
                    val_cost_time += cost_time
                    mean_val_c += val_c
                    mean_val_binary_loss += val_binary_loss
                    mean_val_instance_loss += val_instance_loss
                    mean_val_accuracy_figure += val_accuracy_figure
                    mean_val_fp_figure += val_fp_figure
                    mean_val_fn_figure += val_fn_figure
                    summary_writer.add_summary(summary=val_summary, global_step=step)

                mean_val_c /= val_steps
                mean_val_binary_loss /= val_steps
                mean_val_instance_loss /= val_steps
                mean_val_accuracy_figure /= val_steps
                mean_val_fp_figure /= val_steps
                mean_val_fn_figure /= val_steps

                # ==============================
                if mean_val_accuracy_figure > max_acc:
                    max_acc = mean_val_accuracy_figure
                    if save_num < 3:  # the first three saves do not raise the bar
                        max_acc = 0.9
                    log.info('MAX_ACC change to {}'.format(mean_val_accuracy_figure))
                    model_save_path_max = ops.join(
                        model_save_dir,
                        'tusimple_lanenet_{}.ckpt'.format(mean_val_accuracy_figure))
                    saver.save(sess=sess, save_path=model_save_path_max,
                               global_step=global_step)
                    save_num += 1
                # ==============================

                log.info(
                    'MEAN Val: total_loss= {:6f} binary_seg_loss= {:6f} '
                    'instance_seg_loss= {:6f} accuracy= {:6f} fp= {:6f} fn= {:6f}'
                    ' mean_cost_time= {:5f}s '.format(
                        mean_val_c, mean_val_binary_loss, mean_val_instance_loss,
                        mean_val_accuracy_figure, mean_val_fp_figure,
                        mean_val_fn_figure, val_cost_time))
            # ---------------------------------------------------------------- #
    return
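# A minimal command-line entry point for the final version might look like the
# sketch below. The flag names and defaults are illustrative assumptions, not
# taken from the repo.
import argparse


def init_args():
    """Parse command-line flags for train_lanenet (hypothetical names)."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset_dir', type=str, required=True,
                        help='directory holding the generated train/val data')
    parser.add_argument('--weights_path', type=str, default=None,
                        help='optional checkpoint to restore before training')
    parser.add_argument('--net_flag', type=str, default='vgg',
                        help='which base network to use')
    parser.add_argument('--version_flag', type=str, default='',
                        help='experiment tag appended to model/tboard dirs')
    parser.add_argument('--scratch', action='store_true',
                        help='restore weights_path without Momentum/global_step')
    return parser.parse_args()


if __name__ == '__main__':
    args = init_args()
    train_lanenet(args.dataset_dir, weights_path=args.weights_path,
                  net_flag=args.net_flag, version_flag=args.version_flag,
                  scratch=args.scratch)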