def test_lanenet_batch(image_list, weights_path, batch_size, use_gpu, net_flag='vgg'):
    """
    Evaluate a trained LaneNet checkpoint over a whole image list and report
    mean recall / precision / accuracy / fp / fn of the binary segmentation.

    :param image_list: path to a text file; each line is whitespace-separated,
        column 0 the image path, column 1 the binary ground-truth label path
    :param weights_path: checkpoint file to restore
    :param batch_size: number of images per session run
    :param use_gpu: if truthy, allow one GPU; otherwise force CPU
    :param net_flag: backbone selector, 'vgg' or 'mobilenet_v2'
    :return: None
    """
    assert ops.exists(image_list), '{:s} not exist'.format(image_list)

    log.info('开始加载数据集列表...')
    # NOTE(review): 'traing' looks like a typo for 'training', but it must match
    # the DataSet constructor's actual keyword — confirm before renaming.
    test_dataset = lanenet_data_processor.DataSet(image_list, traing=False)

    # Column 1 of each list line is the ground-truth binary label path; it is
    # reused below as the relative output path for the predicted masks.
    gt_label_binary_list = []
    with open(image_list, 'r') as file:
        for _info in file:
            info_tmp = _info.strip(' ').split()
            gt_label_binary_list.append(info_tmp[1])

    input_tensor = tf.placeholder(dtype=tf.float32, shape=[None, 256, 512, 3],
                                  name='input_tensor')
    binary_label_tensor = tf.placeholder(dtype=tf.int64, shape=[None, 256, 512, 1],
                                         name='binary_input_label')
    phase_tensor = tf.constant('test', tf.string)

    net = lanenet.LaneNet(phase=phase_tensor, net_flag=net_flag)
    binary_seg_ret, instance_seg_ret, recall_ret, false_positive, false_negative, precision_ret, accuracy_ret = \
        net.compute_acc(input_tensor=input_tensor, binary_label_tensor=binary_label_tensor,
                        name='lanenet_model')

    saver = tf.train.Saver()

    # Set sess configuration
    if use_gpu:
        sess_config = tf.ConfigProto(device_count={'GPU': 1})
    else:
        sess_config = tf.ConfigProto(device_count={'GPU': 0})
    sess_config.gpu_options.per_process_gpu_memory_fraction = CFG.TEST.GPU_MEMORY_FRACTION
    sess_config.gpu_options.allow_growth = CFG.TRAIN.TF_ALLOW_GROWTH
    sess_config.gpu_options.allocator_type = 'BFC'

    sess = tf.Session(config=sess_config)

    with sess.as_default():
        saver.restore(sess=sess, save_path=weights_path)
        epoch_nums = int(math.ceil(test_dataset._dataset_size / batch_size))

        mean_accuracy = 0.0
        mean_recall = 0.0
        mean_precision = 0.0
        mean_fp = 0.0
        mean_fn = 0.0
        total_num = 0
        t_start = time.time()
        for epoch in range(epoch_nums):
            gt_imgs, binary_gt_labels, instance_gt_labels = test_dataset.next_batch(batch_size)
            # Backbone-specific input normalisation.
            if net_flag == 'vgg':
                image_list_epoch = [tmp / 127.5 - 1.0 for tmp in gt_imgs]
            elif net_flag == 'mobilenet_v2':
                image_list_epoch = [tmp - [103.939, 116.779, 123.68] for tmp in gt_imgs]
            else:
                # BUG FIX: an unsupported net_flag previously left
                # image_list_epoch unbound and crashed with UnboundLocalError
                # at the sess.run call below; fail fast with a clear message.
                raise ValueError('unsupported net_flag: {:s}'.format(net_flag))

            binary_seg_images, instance_seg_images, recall, fp, fn, precision, accuracy = sess.run(
                [binary_seg_ret, instance_seg_ret, recall_ret, false_positive,
                 false_negative, precision_ret, accuracy_ret],
                feed_dict={input_tensor: image_list_epoch,
                           binary_label_tensor: binary_gt_labels})

            # Dump the first predicted mask of the batch, mirroring the label's
            # relative path under out_dir.
            # NOTE(review): indexing by `epoch` and writing only element [0]
            # saves a single mask per batch; with batch_size > 1 most
            # predictions are never written — confirm intent.
            out_dir = 'H:/Other_DataSets/TuSimple/out/'
            dst_binary_image_path = ops.join(out_dir, gt_label_binary_list[epoch])
            root_dir = ops.dirname(ops.abspath(dst_binary_image_path))
            if not os.path.exists(root_dir):
                os.makedirs(root_dir)
            cv2.imwrite(dst_binary_image_path, binary_seg_images[0] * 255)

            print(recall, fp, fn)
            mean_accuracy += accuracy
            mean_precision += precision
            mean_recall += recall
            mean_fp += fp
            mean_fn += fn
            total_num += len(gt_imgs)
        t_cost = time.time() - t_start

        # Averages are over batches, not images; the last (possibly smaller)
        # batch is weighted the same as full batches.
        mean_accuracy = mean_accuracy / epoch_nums
        mean_precision = mean_precision / epoch_nums
        mean_recall = mean_recall / epoch_nums
        mean_fp = mean_fp / epoch_nums
        mean_fn = mean_fn / epoch_nums
        print('测试 {} 张图片,耗时{},{}_recall = {}, precision = {}, accuracy = {}, fp = {}, fn = {}, '.format(
            total_num, t_cost, net_flag, mean_recall, mean_precision, mean_accuracy, mean_fp, mean_fn))
    sess.close()
def train_net(dataset_dir, weights_path=None, net_flag='vgg'):
    """
    Train LaneNet (binary + instance segmentation heads) on a KITTI-style
    dataset and log train/val metrics to TensorBoard.

    :param dataset_dir: directory containing 'train.txt' and 'val.txt' lists
    :param net_flag: choose which base network to use ('vgg' also loads
        ImageNet-pretrained VGG16 weights)
    :param weights_path: optional checkpoint to resume from; None trains from scratch
    :return: None
    """
    train_dataset_file = ops.join(dataset_dir, 'train.txt')
    val_dataset_file = ops.join(dataset_dir, 'val.txt')

    assert ops.exists(train_dataset_file)

    train_dataset = lanenet_data_processor.DataSet(train_dataset_file)
    val_dataset = lanenet_data_processor.DataSet(val_dataset_file)

    input_tensor = tf.placeholder(
        dtype=tf.float32,
        shape=[None, CFG.TRAIN.IMG_HEIGHT, CFG.TRAIN.IMG_WIDTH, 3],
        name='input_tensor')
    binary_label_tensor = tf.placeholder(
        dtype=tf.int64,
        shape=[None, CFG.TRAIN.IMG_HEIGHT, CFG.TRAIN.IMG_WIDTH, 1],
        name='binary_input_label')
    instance_label_tensor = tf.placeholder(
        dtype=tf.float32,
        shape=[None, CFG.TRAIN.IMG_HEIGHT, CFG.TRAIN.IMG_WIDTH],
        name='instance_input_label')
    phase = tf.placeholder(dtype=tf.string, shape=None, name='net_phase')

    net = lanenet_merge_model.LaneNet(net_flag=net_flag, phase=phase)

    # calculate the loss
    compute_ret = net.compute_loss(input_tensor=input_tensor,
                                   binary_label=binary_label_tensor,
                                   instance_label=instance_label_tensor,
                                   name='lanenet_loss')
    total_loss = compute_ret['total_loss']
    binary_seg_loss = compute_ret['binary_seg_loss']
    disc_loss = compute_ret['discriminative_loss']
    pix_embedding = compute_ret['instance_seg_logits']

    # calculate the accuracy: 1 - (fraction of pixels whose argmax class
    # differs from the binary ground truth), averaged over the batch.
    out_logits = compute_ret['binary_seg_logits']
    out_logits = tf.nn.softmax(logits=out_logits)
    out_logits_out = tf.argmax(out_logits, axis=-1)
    # Reuse the argmax above instead of recomputing it (was duplicated).
    out = tf.expand_dims(out_logits_out, axis=-1)
    accuracy = tf.add(binary_label_tensor, -1 * out)
    accuracy = tf.count_nonzero(accuracy, axis=[1, 2, 3])
    accuracy = tf.add(
        tf.constant(1, dtype=tf.float64),
        -1 * tf.divide(accuracy, CFG.TRAIN.IMG_HEIGHT * CFG.TRAIN.IMG_WIDTH))
    accuracy = tf.reduce_mean(accuracy, axis=0)

    global_step = tf.Variable(0, trainable=False)
    learning_rate = tf.train.exponential_decay(CFG.TRAIN.LEARNING_RATE, global_step,
                                               5000, 0.96, staircase=True)
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    # Batch-norm moving stats must update before each optimizer step.
    with tf.control_dependencies(update_ops):
        optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(
            loss=total_loss, var_list=tf.trainable_variables(),
            global_step=global_step)

    # Set tf saver
    saver = tf.train.Saver()
    model_save_dir = 'model/kitti_lanenet'
    if not ops.exists(model_save_dir):
        os.makedirs(model_save_dir)
    train_start_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
    model_name = 'kitti_lanenet_{:s}_{:s}.ckpt'.format(net_flag, str(train_start_time))
    model_save_path = ops.join(model_save_dir, model_name)

    # Set tf summary
    tboard_save_path = 'tboard/kitti_lanenet/{:s}'.format(net_flag)
    if not ops.exists(tboard_save_path):
        os.makedirs(tboard_save_path)
    train_cost_scalar = tf.summary.scalar(name='train_cost', tensor=total_loss)
    val_cost_scalar = tf.summary.scalar(name='val_cost', tensor=total_loss)
    train_accuracy_scalar = tf.summary.scalar(name='train_accuracy', tensor=accuracy)
    val_accuracy_scalar = tf.summary.scalar(name='val_accuracy', tensor=accuracy)
    train_binary_seg_loss_scalar = tf.summary.scalar(
        name='train_binary_seg_loss', tensor=binary_seg_loss)
    val_binary_seg_loss_scalar = tf.summary.scalar(name='val_binary_seg_loss',
                                                   tensor=binary_seg_loss)
    train_instance_seg_loss_scalar = tf.summary.scalar(
        name='train_instance_seg_loss', tensor=disc_loss)
    val_instance_seg_loss_scalar = tf.summary.scalar(
        name='val_instance_seg_loss', tensor=disc_loss)
    learning_rate_scalar = tf.summary.scalar(name='learning_rate', tensor=learning_rate)
    train_merge_summary_op = tf.summary.merge([
        train_accuracy_scalar, train_cost_scalar, learning_rate_scalar,
        train_binary_seg_loss_scalar, train_instance_seg_loss_scalar
    ])
    val_merge_summary_op = tf.summary.merge([
        val_accuracy_scalar, val_cost_scalar, val_binary_seg_loss_scalar,
        val_instance_seg_loss_scalar
    ])

    # Set sess configuration
    sess_config = tf.ConfigProto(device_count={'GPU': 1})
    sess_config.gpu_options.per_process_gpu_memory_fraction = CFG.TRAIN.GPU_MEMORY_FRACTION
    sess_config.gpu_options.allow_growth = CFG.TRAIN.TF_ALLOW_GROWTH
    sess_config.gpu_options.allocator_type = 'BFC'
    sess = tf.Session(config=sess_config)

    summary_writer = tf.summary.FileWriter(tboard_save_path)
    summary_writer.add_graph(sess.graph)

    # Set the training parameters
    train_epochs = CFG.TRAIN.EPOCHS

    log.info('Global configuration is as follows:')
    log.info(CFG)

    with sess.as_default():
        tf.train.write_graph(
            graph_or_graph_def=sess.graph, logdir='',
            name='{:s}/lanenet_model.pb'.format(model_save_dir))

        if weights_path is None:
            log.info('Training from scratch')
            init = tf.global_variables_initializer()
            sess.run(init)
        else:
            log.info('Restore model from last model checkpoint {:s}'.format(
                weights_path))
            saver.restore(sess=sess, save_path=weights_path)

        # Load ImageNet-pretrained VGG16 weights layer-by-layer; variables
        # whose names do not match a key (or whose shapes differ) keep their
        # current values — the broad except is a deliberate best-effort skip.
        if net_flag == 'vgg':
            pretrained_weights = np.load(
                '/home/baidu/Silly_Project/ICode/baidu/beec/semantic-road-estimation/data/vgg16.npy',
                encoding='latin1').item()
            for vv in tf.trainable_variables():
                weights_key = vv.name.split('/')[-3]
                try:
                    weights = pretrained_weights[weights_key][0]
                    _op = tf.assign(vv, weights)
                    sess.run(_op)
                except Exception:
                    continue

        train_cost_time_mean = []
        val_cost_time_mean = []
        for epoch in range(train_epochs):
            # training part
            t_start = time.time()

            gt_imgs, binary_gt_labels, instance_gt_labels = train_dataset.next_batch(
                CFG.TRAIN.BATCH_SIZE)
            gt_imgs = [
                cv2.resize(tmp, dsize=(CFG.TRAIN.IMG_WIDTH, CFG.TRAIN.IMG_HEIGHT),
                           dst=tmp, interpolation=cv2.INTER_LINEAR)
                for tmp in gt_imgs
            ]
            gt_imgs = [tmp - VGG_MEAN for tmp in gt_imgs]
            # Labels must use nearest-neighbour resizing so class ids stay exact.
            binary_gt_labels = [
                cv2.resize(tmp, dsize=(CFG.TRAIN.IMG_WIDTH, CFG.TRAIN.IMG_HEIGHT),
                           dst=tmp, interpolation=cv2.INTER_NEAREST)
                for tmp in binary_gt_labels
            ]
            binary_gt_labels = [np.expand_dims(tmp, axis=-1) for tmp in binary_gt_labels]
            instance_gt_labels = [
                cv2.resize(tmp, dsize=(CFG.TRAIN.IMG_WIDTH, CFG.TRAIN.IMG_HEIGHT),
                           dst=tmp, interpolation=cv2.INTER_NEAREST)
                for tmp in instance_gt_labels
            ]
            phase_train = 'train'

            _, c, train_accuracy, train_summary, binary_loss, instance_loss, embedding, binary_seg_img = \
                sess.run([optimizer, total_loss, accuracy, train_merge_summary_op,
                          binary_seg_loss, disc_loss, pix_embedding, out_logits_out],
                         feed_dict={input_tensor: gt_imgs,
                                    binary_label_tensor: binary_gt_labels,
                                    instance_label_tensor: instance_gt_labels,
                                    phase: phase_train})

            # Dump the offending batch for post-mortem when the loss diverges.
            if math.isnan(c) or math.isnan(binary_loss) or math.isnan(instance_loss):
                log.error('cost is: {:.5f}'.format(c))
                log.error('binary cost is: {:.5f}'.format(binary_loss))
                log.error('instance cost is: {:.5f}'.format(instance_loss))
                cv2.imwrite('nan_image.png', gt_imgs[0] + VGG_MEAN)
                cv2.imwrite('nan_instance_label.png', instance_gt_labels[0])
                cv2.imwrite('nan_binary_label.png', binary_gt_labels[0] * 255)
                cv2.imwrite('nan_embedding.png', embedding[0])
                return

            if epoch % 100 == 0:
                cv2.imwrite('image.png', gt_imgs[0] + VGG_MEAN)
                cv2.imwrite('binary_label.png', binary_gt_labels[0] * 255)
                cv2.imwrite('instance_label.png', instance_gt_labels[0])
                cv2.imwrite('binary_seg_img.png', binary_seg_img[0] * 255)
                cv2.imwrite('embedding.png', embedding[0])

            cost_time = time.time() - t_start
            train_cost_time_mean.append(cost_time)
            summary_writer.add_summary(summary=train_summary, global_step=epoch)

            # validation part
            gt_imgs_val, binary_gt_labels_val, instance_gt_labels_val \
                = val_dataset.next_batch(CFG.TRAIN.VAL_BATCH_SIZE)
            gt_imgs_val = [
                cv2.resize(tmp, dsize=(CFG.TRAIN.IMG_WIDTH, CFG.TRAIN.IMG_HEIGHT),
                           dst=tmp, interpolation=cv2.INTER_LINEAR)
                for tmp in gt_imgs_val
            ]
            gt_imgs_val = [tmp - VGG_MEAN for tmp in gt_imgs_val]
            # BUG FIX: this resize previously used the default (bilinear)
            # interpolation, unlike the training branch above; bilinear blends
            # the 0/1 binary labels into invalid fractional values along lane
            # boundaries, skewing the reported validation loss/accuracy.
            binary_gt_labels_val = [
                cv2.resize(tmp, dsize=(CFG.TRAIN.IMG_WIDTH, CFG.TRAIN.IMG_HEIGHT),
                           dst=tmp, interpolation=cv2.INTER_NEAREST)
                for tmp in binary_gt_labels_val
            ]
            binary_gt_labels_val = [
                np.expand_dims(tmp, axis=-1) for tmp in binary_gt_labels_val
            ]
            instance_gt_labels_val = [
                cv2.resize(tmp, dsize=(CFG.TRAIN.IMG_WIDTH, CFG.TRAIN.IMG_HEIGHT),
                           dst=tmp, interpolation=cv2.INTER_NEAREST)
                for tmp in instance_gt_labels_val
            ]
            phase_val = 'test'

            t_start_val = time.time()
            c_val, val_summary, val_accuracy, val_binary_seg_loss, val_instance_seg_loss = \
                sess.run([total_loss, val_merge_summary_op, accuracy,
                          binary_seg_loss, disc_loss],
                         feed_dict={input_tensor: gt_imgs_val,
                                    binary_label_tensor: binary_gt_labels_val,
                                    instance_label_tensor: instance_gt_labels_val,
                                    phase: phase_val})

            if epoch % 100 == 0:
                cv2.imwrite('test_image.png', gt_imgs_val[0] + VGG_MEAN)

            summary_writer.add_summary(val_summary, global_step=epoch)
            cost_time_val = time.time() - t_start_val
            val_cost_time_mean.append(cost_time_val)

            if epoch % CFG.TRAIN.DISPLAY_STEP == 0:
                log.info(
                    'Epoch: {:d} total_loss= {:6f} binary_seg_loss= {:6f} instance_seg_loss= {:6f} accuracy= {:6f}'
                    ' mean_cost_time= {:5f}s '.format(
                        epoch + 1, c, binary_loss, instance_loss, train_accuracy,
                        np.mean(train_cost_time_mean)))
                train_cost_time_mean.clear()

            if epoch % CFG.TRAIN.TEST_DISPLAY_STEP == 0:
                log.info(
                    'Epoch_Val: {:d} total_loss= {:6f} binary_seg_loss= {:6f} '
                    'instance_seg_loss= {:6f} accuracy= {:6f} '
                    'mean_cost_time= {:5f}s '.format(
                        epoch + 1, c_val, val_binary_seg_loss,
                        val_instance_seg_loss, val_accuracy,
                        np.mean(val_cost_time_mean)))
                val_cost_time_mean.clear()

            if epoch % 2000 == 0:
                saver.save(sess=sess, save_path=model_save_path, global_step=epoch)
    sess.close()

    return
def train_net(dataset_dir, weights_path=None, net_flag='vgg'):
    """
    Train the SCNN-style LaneNet on CULane-format data: a 5-class pixel
    labelling (background + 4 lanes) plus a 4-way lane-existence head.

    :param dataset_dir: directory containing 'train_gt.txt' and 'val_gt.txt'
    :param net_flag: choose which base network to use
    :param weights_path: optional checkpoint to resume from; None trains from
        scratch (and, for 'vgg', loads pretrained VGG16 weights)
    :return: None
    """
    train_dataset_file = ops.join(dataset_dir, 'train_gt.txt')
    val_dataset_file = ops.join(dataset_dir, 'val_gt.txt')

    assert ops.exists(train_dataset_file)

    train_dataset = lanenet_data_processor.DataSet(train_dataset_file)
    val_dataset = lanenet_data_processor.DataSet(val_dataset_file)

    # Placeholders use a fixed leading batch dimension (CFG.TRAIN.BATCH_SIZE),
    # so every fed batch — train and validation — must have exactly that size.
    input_tensor = tf.placeholder(dtype=tf.float32,
                                  shape=[
                                      CFG.TRAIN.BATCH_SIZE,
                                      CFG.TRAIN.IMG_HEIGHT,
                                      CFG.TRAIN.IMG_WIDTH, 3
                                  ],
                                  name='input_tensor')
    instance_label_tensor = tf.placeholder(dtype=tf.int64,
                                           shape=[
                                               CFG.TRAIN.BATCH_SIZE,
                                               CFG.TRAIN.IMG_HEIGHT,
                                               CFG.TRAIN.IMG_WIDTH
                                           ],
                                           name='instance_input_label')
    # One existence flag per lane (4 lanes).
    existence_label_tensor = tf.placeholder(dtype=tf.float32,
                                            shape=[CFG.TRAIN.BATCH_SIZE, 4],
                                            name='existence_input_label')
    phase = tf.placeholder(dtype=tf.string, shape=None, name='net_phase')

    net = lanenet_merge_model.LaneNet(net_flag=net_flag, phase=phase)

    # calculate the loss
    compute_ret = net.compute_loss(input_tensor=input_tensor,
                                   binary_label=instance_label_tensor,
                                   existence_label=existence_label_tensor,
                                   name='lanenet_loss')
    total_loss = compute_ret['total_loss']
    instance_loss = compute_ret['instance_seg_loss']
    existence_loss = compute_ret['existence_pre_loss']
    existence_logits = compute_ret['existence_logits']

    # calculate the accuracy
    out_logits = compute_ret['instance_seg_logits']
    out_logits_ref = out_logits
    out_logits = tf.nn.softmax(logits=out_logits)
    out_logits_out = tf.argmax(out_logits, axis=-1)  # 8 x 288 x 800

    # pred_k = number of pixels predicted as class k where the label is also k
    # (per-class true positives; class 0 is background).
    pred_0 = tf.count_nonzero(
        tf.multiply(tf.cast(tf.equal(instance_label_tensor, 0), tf.int64),
                    tf.cast(tf.equal(out_logits_out, 0), tf.int64)))
    pred_1 = tf.count_nonzero(
        tf.multiply(tf.cast(tf.equal(instance_label_tensor, 1), tf.int64),
                    tf.cast(tf.equal(out_logits_out, 1), tf.int64)))
    pred_2 = tf.count_nonzero(
        tf.multiply(tf.cast(tf.equal(instance_label_tensor, 2), tf.int64),
                    tf.cast(tf.equal(out_logits_out, 2), tf.int64)))
    pred_3 = tf.count_nonzero(
        tf.multiply(tf.cast(tf.equal(instance_label_tensor, 3), tf.int64),
                    tf.cast(tf.equal(out_logits_out, 3), tf.int64)))
    pred_4 = tf.count_nonzero(
        tf.multiply(tf.cast(tf.equal(instance_label_tensor, 4), tf.int64),
                    tf.cast(tf.equal(out_logits_out, 4), tf.int64)))
    gt_all = tf.count_nonzero(
        tf.cast(tf.greater(instance_label_tensor, 0), tf.int64))
    gt_back = tf.count_nonzero(
        tf.cast(tf.equal(instance_label_tensor, 0), tf.int64))
    pred_all = tf.add(tf.add(tf.add(pred_1, pred_2), pred_3), pred_4)
    # Recall over lane pixels and over background pixels respectively.
    accuracy = tf.divide(pred_all, gt_all)
    accuracy_back = tf.divide(pred_0, gt_back)

    # Compute mIoU of Lanes
    # NOTE(review): if a lane class is absent from both label and prediction,
    # its union is 0 and IoU_k becomes NaN, dragging the mean — confirm every
    # validation batch contains all 4 lanes.
    overlap_1 = pred_1
    union_1 = tf.add(
        tf.count_nonzero(tf.cast(tf.equal(instance_label_tensor, 1), tf.int64)),
        tf.count_nonzero(tf.cast(tf.equal(out_logits_out, 1), tf.int64)))
    union_1 = tf.subtract(union_1, overlap_1)
    IoU_1 = tf.divide(overlap_1, union_1)
    overlap_2 = pred_2
    union_2 = tf.add(
        tf.count_nonzero(tf.cast(tf.equal(instance_label_tensor, 2), tf.int64)),
        tf.count_nonzero(tf.cast(tf.equal(out_logits_out, 2), tf.int64)))
    union_2 = tf.subtract(union_2, overlap_2)
    IoU_2 = tf.divide(overlap_2, union_2)
    overlap_3 = pred_3
    union_3 = tf.add(
        tf.count_nonzero(tf.cast(tf.equal(instance_label_tensor, 3), tf.int64)),
        tf.count_nonzero(tf.cast(tf.equal(out_logits_out, 3), tf.int64)))
    union_3 = tf.subtract(union_3, overlap_3)
    IoU_3 = tf.divide(overlap_3, union_3)
    overlap_4 = pred_4
    union_4 = tf.add(
        tf.count_nonzero(tf.cast(tf.equal(instance_label_tensor, 4), tf.int64)),
        tf.count_nonzero(tf.cast(tf.equal(out_logits_out, 4), tf.int64)))
    union_4 = tf.subtract(union_4, overlap_4)
    IoU_4 = tf.divide(overlap_4, union_4)
    IoU = tf.reduce_mean(tf.stack([IoU_1, IoU_2, IoU_3, IoU_4]))

    global_step = tf.Variable(0, trainable=False)
    # Polynomial decay over a hard-coded 90100 steps.
    learning_rate = tf.train.polynomial_decay(CFG.TRAIN.LEARNING_RATE,
                                              global_step, 90100, power=0.9)
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    # Batch-norm moving stats must update before each optimizer step.
    with tf.control_dependencies(update_ops):
        optimizer = tf.train.MomentumOptimizer(
            learning_rate=learning_rate,
            momentum=0.9).minimize(loss=total_loss,
                                   var_list=tf.trainable_variables(),
                                   global_step=global_step)

    # Set tf saver
    saver = tf.train.Saver()
    model_save_dir = 'model/culane_lanenet/culane_scnn'
    if not ops.exists(model_save_dir):
        os.makedirs(model_save_dir)
    train_start_time = time.strftime('%Y-%m-%d-%H-%M-%S',
                                     time.localtime(time.time()))
    model_name = 'culane_lanenet_{:s}_{:s}.ckpt'.format(
        net_flag, str(train_start_time))
    model_save_path = ops.join(model_save_dir, model_name)

    # Set sess configuration
    sess_config = tf.ConfigProto(device_count={'GPU': 4})  # device_count={'GPU': 1}
    sess_config.gpu_options.per_process_gpu_memory_fraction = CFG.TRAIN.GPU_MEMORY_FRACTION
    sess_config.gpu_options.allow_growth = CFG.TRAIN.TF_ALLOW_GROWTH
    sess_config.gpu_options.allocator_type = 'BFC'
    sess = tf.Session(config=sess_config)

    # Set the training parameters
    train_epochs = CFG.TRAIN.EPOCHS

    log.info('Global configuration is as follows:')
    log.info(CFG)

    with sess.as_default():
        if weights_path is None:
            log.info('Training from scratch')
            init = tf.global_variables_initializer()
            sess.run(init)
        else:
            log.info('Restore model from last model checkpoint {:s}'.format(
                weights_path))
            saver.restore(sess=sess, save_path=weights_path)

        # Load the pretrained parameters (VGG16); variables whose names do not
        # match a key keep their initial values — the broad except is a
        # deliberate best-effort skip.
        if net_flag == 'vgg' and weights_path is None:
            pretrained_weights = np.load('./data/vgg16.npy',
                                         encoding='latin1').item()
            for vv in tf.trainable_variables():
                weights_key = vv.name.split('/')[-3]
                try:
                    weights = pretrained_weights[weights_key][0]
                    _op = tf.assign(vv, weights)
                    sess.run(_op)
                except Exception as e:
                    continue

        # Running-mean accumulators; train ones are cleared every 500 epochs,
        # val ones after each validation sweep.
        train_cost_time_mean = []
        train_instance_loss_mean = []
        train_existence_loss_mean = []
        train_accuracy_mean = []
        train_accuracy_back_mean = []
        val_cost_time_mean = []
        val_instance_loss_mean = []
        val_existence_loss_mean = []
        val_accuracy_mean = []
        val_accuracy_back_mean = []
        val_IoU_mean = []
        for epoch in range(train_epochs):
            # training part
            t_start = time.time()

            gt_imgs, instance_gt_labels, existence_gt_labels = train_dataset.next_batch(
                CFG.TRAIN.BATCH_SIZE)
            gt_imgs = [
                cv2.resize(tmp,
                           dsize=(CFG.TRAIN.IMG_WIDTH, CFG.TRAIN.IMG_HEIGHT),
                           dst=tmp,
                           interpolation=cv2.INTER_CUBIC) for tmp in gt_imgs
            ]
            gt_imgs = [(tmp - VGG_MEAN) for tmp in gt_imgs]
            # Nearest-neighbour keeps integer class ids intact.
            instance_gt_labels = [
                cv2.resize(tmp,
                           dsize=(CFG.TRAIN.IMG_WIDTH, CFG.TRAIN.IMG_HEIGHT),
                           dst=tmp,
                           interpolation=cv2.INTER_NEAREST)
                for tmp in instance_gt_labels
            ]
            phase_train = 'train'

            _, c, train_accuracy, train_accuracy_back, train_instance_loss, train_existence_loss, binary_seg_img = \
                sess.run([optimizer, total_loss, accuracy, accuracy_back,
                          instance_loss, existence_loss, out_logits_out],
                         feed_dict={input_tensor: gt_imgs,
                                    instance_label_tensor: instance_gt_labels,
                                    existence_label_tensor: existence_gt_labels,
                                    phase: phase_train})

            cost_time = time.time() - t_start
            train_cost_time_mean.append(cost_time)
            train_instance_loss_mean.append(train_instance_loss)
            train_existence_loss_mean.append(train_existence_loss)
            train_accuracy_mean.append(train_accuracy)
            train_accuracy_back_mean.append(train_accuracy_back)

            if epoch % CFG.TRAIN.DISPLAY_STEP == 0:
                print(
                    'Epoch: {:d} loss_ins= {:6f} ({:6f}) loss_ext= {:6f} ({:6f}) accuracy= {:6f} ({:6f}) accuracy_back= {:6f} ({:6f})'
                    ' mean_time= {:5f}s '.format(
                        epoch + 1, train_instance_loss,
                        np.mean(train_instance_loss_mean),
                        train_existence_loss,
                        np.mean(train_existence_loss_mean), train_accuracy,
                        np.mean(train_accuracy_mean), train_accuracy_back,
                        np.mean(train_accuracy_back_mean),
                        np.mean(train_cost_time_mean)))

            if epoch % 500 == 0:
                train_cost_time_mean.clear()
                train_instance_loss_mean.clear()
                train_existence_loss_mean.clear()
                train_accuracy_mean.clear()
                train_accuracy_back_mean.clear()

            if epoch % 1000 == 0:
                saver.save(sess=sess,
                           save_path=model_save_path,
                           global_step=epoch)

            # Full validation sweep only every 10000 epochs (never at epoch 0).
            if epoch % 10000 != 0 or epoch == 0:
                continue
            # NOTE(review): 9675 / 8.0 looks like dataset_size / batch_size
            # hard-coded — confirm against the validation set.
            for epoch_val in range(int(9675 / 8.0)):
                # validation part
                gt_imgs_val, instance_gt_labels_val, existence_gt_labels_val \
                    = val_dataset.next_batch(CFG.TRAIN.VAL_BATCH_SIZE)
                gt_imgs_val = [
                    cv2.resize(tmp,
                               dsize=(CFG.TRAIN.IMG_WIDTH,
                                      CFG.TRAIN.IMG_HEIGHT),
                               dst=tmp,
                               interpolation=cv2.INTER_CUBIC)
                    for tmp in gt_imgs_val
                ]
                gt_imgs_val = [(tmp - VGG_MEAN) for tmp in gt_imgs_val]
                instance_gt_labels_val = [
                    cv2.resize(tmp,
                               dsize=(CFG.TRAIN.IMG_WIDTH,
                                      CFG.TRAIN.IMG_HEIGHT),
                               dst=tmp,
                               interpolation=cv2.INTER_NEAREST)
                    for tmp in instance_gt_labels_val
                ]
                phase_val = 'test'

                t_start_val = time.time()
                c_val, val_accuracy, val_accuracy_back, val_IoU, val_instance_loss, val_existence_loss = \
                    sess.run([total_loss, accuracy, accuracy_back, IoU,
                              instance_loss, existence_loss],
                             feed_dict={input_tensor: gt_imgs_val,
                                        instance_label_tensor: instance_gt_labels_val,
                                        existence_label_tensor: existence_gt_labels_val,
                                        phase: phase_val})
                cost_time_val = time.time() - t_start_val
                val_cost_time_mean.append(cost_time_val)
                val_instance_loss_mean.append(val_instance_loss)
                val_existence_loss_mean.append(val_existence_loss)
                val_accuracy_mean.append(val_accuracy)
                val_accuracy_back_mean.append(val_accuracy_back)
                val_IoU_mean.append(val_IoU)

                # epoch_val % 1 is always 0, i.e. print for every val batch.
                if epoch_val % 1 == 0:
                    print(
                        'Epoch_Val: {:d} loss_ins= {:6f} ({:6f}) '
                        'loss_ext= {:6f} ({:6f}) accuracy= {:6f} ({:6f}) accuracy_back= {:6f} ({:6f}) mIoU= {:6f} ({:6f})'
                        'mean_time= {:5f}s '.format(
                            epoch_val + 1, val_instance_loss,
                            np.mean(val_instance_loss_mean),
                            val_existence_loss,
                            np.mean(val_existence_loss_mean), val_accuracy,
                            np.mean(val_accuracy_mean), val_accuracy_back,
                            np.mean(val_accuracy_back_mean), val_IoU,
                            np.mean(val_IoU_mean),
                            np.mean(val_cost_time_mean)))
            # Reset the validation running means after the sweep.
            val_cost_time_mean.clear()
            val_instance_loss_mean.clear()
            val_existence_loss_mean.clear()
            val_accuracy_mean.clear()
            val_accuracy_back_mean.clear()
            val_IoU_mean.clear()
    sess.close()

    return
def train_net(dataset_dir, weights_path=None, net_flag='vgg'):
    """
    Train LaneNet on CULane-format data with an instance-segmentation head
    and a 4-way lane-existence head; validates one batch every epoch.

    :param dataset_dir: directory containing 'train_gt.txt' and 'val_gt.txt'
    :param net_flag: choose which base network to use
    :param weights_path: optional checkpoint to resume from; None trains from
        scratch (and, for 'vgg', loads pretrained VGG16 weights)
    :return: None
    """
    train_dataset_file = ops.join(dataset_dir, 'train_gt.txt')
    val_dataset_file = ops.join(dataset_dir, 'val_gt.txt')

    assert ops.exists(train_dataset_file)

    train_dataset = lanenet_data_processor.DataSet(train_dataset_file)
    val_dataset = lanenet_data_processor.DataSet(val_dataset_file)

    # Placeholders use a fixed leading batch dimension (CFG.TRAIN.BATCH_SIZE),
    # so every fed batch — train and validation — must have exactly that size.
    input_tensor = tf.placeholder(dtype=tf.float32,
                                  shape=[
                                      CFG.TRAIN.BATCH_SIZE,
                                      CFG.TRAIN.IMG_HEIGHT,
                                      CFG.TRAIN.IMG_WIDTH, 3
                                  ],
                                  name='input_tensor')
    instance_label_tensor = tf.placeholder(dtype=tf.int64,
                                           shape=[
                                               CFG.TRAIN.BATCH_SIZE,
                                               CFG.TRAIN.IMG_HEIGHT,
                                               CFG.TRAIN.IMG_WIDTH
                                           ],
                                           name='instance_input_label')
    # One existence flag per lane (4 lanes).
    existence_label_tensor = tf.placeholder(dtype=tf.float32,
                                            shape=[CFG.TRAIN.BATCH_SIZE, 4],
                                            name='existence_input_label')
    phase = tf.placeholder(dtype=tf.string, shape=None, name='net_phase')

    net = lanenet_merge_model.LaneNet(net_flag=net_flag, phase=phase)

    # calculate the loss
    compute_ret = net.compute_loss(input_tensor=input_tensor,
                                   binary_label=instance_label_tensor,
                                   existence_label=existence_label_tensor,
                                   name='lanenet_loss')
    total_loss = compute_ret['total_loss']
    instance_loss = compute_ret['instance_seg_loss']
    existence_loss = compute_ret['existence_pre_loss']
    existence_logits = compute_ret['existence_logits']

    # calculate the accuracy: fraction of pixels labelled class 1 that are
    # predicted non-background.
    out_logits = compute_ret['instance_seg_logits']
    out_logits = tf.nn.softmax(logits=out_logits)
    out_logits_out = tf.argmax(out_logits, axis=-1)

    out = tf.argmax(out_logits, axis=-1)
    out = tf.expand_dims(out, axis=-1)
    idx = tf.where(tf.equal(instance_label_tensor, 1))
    pix_cls_ret = tf.gather_nd(out, idx)
    accuracy = tf.count_nonzero(pix_cls_ret)
    accuracy = tf.divide(accuracy,
                         tf.cast(tf.shape(pix_cls_ret)[0], tf.int64))

    global_step = tf.Variable(0, trainable=False)
    learning_rate = tf.train.exponential_decay(CFG.TRAIN.LEARNING_RATE,
                                               global_step,
                                               5000,
                                               0.96,
                                               staircase=True)
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    # Batch-norm moving stats must update before each optimizer step.
    with tf.control_dependencies(update_ops):
        optimizer = tf.train.AdamOptimizer(
            learning_rate=learning_rate).minimize(
                loss=total_loss,
                var_list=tf.trainable_variables(),
                global_step=global_step)

    # Set tf saver
    saver = tf.train.Saver()
    model_save_dir = 'model/culane_lanenet/culane_scnn'
    if not ops.exists(model_save_dir):
        os.makedirs(model_save_dir)
    train_start_time = time.strftime('%Y-%m-%d-%H-%M-%S',
                                     time.localtime(time.time()))
    model_name = 'culane_lanenet_{:s}_{:s}.ckpt'.format(
        net_flag, str(train_start_time))
    model_save_path = ops.join(model_save_dir, model_name)

    # Set sess configuration
    sess_config = tf.ConfigProto(device_count={'GPU': 4})  # device_count={'GPU': 1}
    sess_config.gpu_options.per_process_gpu_memory_fraction = CFG.TRAIN.GPU_MEMORY_FRACTION
    sess_config.gpu_options.allow_growth = CFG.TRAIN.TF_ALLOW_GROWTH
    sess_config.gpu_options.allocator_type = 'BFC'
    sess = tf.Session(config=sess_config)

    # Set the training parameters
    train_epochs = CFG.TRAIN.EPOCHS

    log.info('Global configuration is as follows:')
    log.info(CFG)

    with sess.as_default():
        tf.train.write_graph(
            graph_or_graph_def=sess.graph,
            logdir='',
            name='{:s}/lanenet_model.pb'.format(model_save_dir))

        if weights_path is None:
            log.info('Training from scratch')
            init = tf.global_variables_initializer()
            sess.run(init)
        else:
            log.info('Restore model from last model checkpoint {:s}'.format(
                weights_path))
            saver.restore(sess=sess, save_path=weights_path)

        # Load the pretrained parameters (VGG16); variables whose names do not
        # match a key keep their initial values — the broad except is a
        # deliberate best-effort skip.
        if net_flag == 'vgg' and weights_path is None:
            pretrained_weights = np.load('./data/vgg16.npy',
                                         encoding='latin1').item()
            for vv in tf.trainable_variables():
                weights_key = vv.name.split('/')[-3]
                try:
                    weights = pretrained_weights[weights_key][0]
                    _op = tf.assign(vv, weights)
                    sess.run(_op)
                except Exception as e:
                    continue

        # Running-mean accumulators, cleared every 500 epochs below.
        train_cost_time_mean = []
        train_instance_loss_mean = []
        train_existence_loss_mean = []
        train_accuracy_mean = []
        val_cost_time_mean = []
        val_instance_loss_mean = []
        val_existence_loss_mean = []
        val_accuracy_mean = []
        for epoch in range(train_epochs):
            # training part
            t_start = time.time()

            gt_imgs, instance_gt_labels, existence_gt_labels = train_dataset.next_batch(
                CFG.TRAIN.BATCH_SIZE)
            gt_imgs = [
                cv2.resize(tmp,
                           dsize=(CFG.TRAIN.IMG_WIDTH, CFG.TRAIN.IMG_HEIGHT),
                           dst=tmp,
                           interpolation=cv2.INTER_LINEAR) for tmp in gt_imgs
            ]
            gt_imgs = [tmp - VGG_MEAN for tmp in gt_imgs]
            # Nearest-neighbour keeps integer class ids intact.
            instance_gt_labels = [
                cv2.resize(tmp,
                           dsize=(CFG.TRAIN.IMG_WIDTH, CFG.TRAIN.IMG_HEIGHT),
                           dst=tmp,
                           interpolation=cv2.INTER_NEAREST)
                for tmp in instance_gt_labels
            ]
            phase_train = 'train'

            _, c, train_accuracy, train_instance_loss, train_existence_loss, binary_seg_img = \
                sess.run([optimizer, total_loss, accuracy, instance_loss,
                          existence_loss, out_logits_out],
                         feed_dict={input_tensor: gt_imgs,
                                    instance_label_tensor: instance_gt_labels,
                                    existence_label_tensor: existence_gt_labels,
                                    phase: phase_train})

            cost_time = time.time() - t_start
            train_cost_time_mean.append(cost_time)
            train_instance_loss_mean.append(train_instance_loss)
            train_existence_loss_mean.append(train_existence_loss)
            train_accuracy_mean.append(train_accuracy)

            # validation part — note this runs one val batch every epoch.
            gt_imgs_val, instance_gt_labels_val, existence_gt_labels_val \
                = val_dataset.next_batch(CFG.TRAIN.VAL_BATCH_SIZE)
            gt_imgs_val = [
                cv2.resize(tmp,
                           dsize=(CFG.TRAIN.IMG_WIDTH, CFG.TRAIN.IMG_HEIGHT),
                           dst=tmp,
                           interpolation=cv2.INTER_LINEAR)
                for tmp in gt_imgs_val
            ]
            gt_imgs_val = [tmp - VGG_MEAN for tmp in gt_imgs_val]
            instance_gt_labels_val = [
                cv2.resize(tmp,
                           dsize=(CFG.TRAIN.IMG_WIDTH, CFG.TRAIN.IMG_HEIGHT),
                           dst=tmp,
                           interpolation=cv2.INTER_NEAREST)
                for tmp in instance_gt_labels_val
            ]
            phase_val = 'test'

            t_start_val = time.time()
            c_val, val_accuracy, val_instance_loss, val_existence_loss = \
                sess.run([total_loss, accuracy, instance_loss, existence_loss],
                         feed_dict={input_tensor: gt_imgs_val,
                                    instance_label_tensor: instance_gt_labels_val,
                                    existence_label_tensor: existence_gt_labels_val,
                                    phase: phase_val})
            cost_time_val = time.time() - t_start_val
            val_cost_time_mean.append(cost_time_val)
            val_instance_loss_mean.append(val_instance_loss)
            val_existence_loss_mean.append(val_existence_loss)
            val_accuracy_mean.append(val_accuracy)

            if epoch % CFG.TRAIN.DISPLAY_STEP == 0:
                print(
                    'Epoch: {:d} loss_ins= {:6f} ({:6f}) loss_ext= {:6f} ({:6f}) accuracy= {:6f} ({:6f})'
                    ' mean_time= {:5f}s '.format(
                        epoch + 1, train_instance_loss,
                        np.mean(train_instance_loss_mean),
                        train_existence_loss,
                        np.mean(train_existence_loss_mean), train_accuracy,
                        np.mean(train_accuracy_mean),
                        np.mean(train_cost_time_mean)))  # log.info

            if epoch % CFG.TRAIN.TEST_DISPLAY_STEP == 0:
                print('Epoch_Val: {:d} loss_ins= {:6f} ({:6f}) '
                      'loss_ext= {:6f} ({:6f}) accuracy= {:6f} ({:6f})'
                      'mean_time= {:5f}s '.format(
                          epoch + 1, val_instance_loss,
                          np.mean(val_instance_loss_mean),
                          val_existence_loss,
                          np.mean(val_existence_loss_mean), val_accuracy,
                          np.mean(val_accuracy_mean),
                          np.mean(val_cost_time_mean)))

            if epoch % 500 == 0:
                train_cost_time_mean.clear()
                train_instance_loss_mean.clear()
                train_existence_loss_mean.clear()
                train_accuracy_mean.clear()
                val_cost_time_mean.clear()
                val_instance_loss_mean.clear()
                val_existence_loss_mean.clear()
                val_accuracy_mean.clear()

            if epoch % 2000 == 0:
                saver.save(sess=sess,
                           save_path=model_save_path,
                           global_step=epoch)
    sess.close()

    return
def train_net(dataset_dir, weights_path=None, net_flag='vgg'):
    """
    Train the multi-GPU SCNN/LaneNet model on a CULane-style dataset.

    Builds one training tower per GPU fed from a prefetch queue, averages the
    tower gradients, and periodically logs statistics, runs validation and
    saves checkpoints.

    :param dataset_dir: directory containing 'train_gt.txt' and 'val_gt.txt'
                        index files (format defined by lanenet_data_processor.DataSet)
    :param weights_path: optional checkpoint to restore before training
    :param net_flag: base network tag; used here only in the checkpoint file name
    :return: None
    """
    train_dataset_file = ops.join(dataset_dir, 'train_gt.txt')
    val_dataset_file = ops.join(dataset_dir, 'val_gt.txt')
    assert ops.exists(train_dataset_file)

    # 'train' / 'test' is fed at run time to switch network phase behaviour
    # (presumably batch-norm / dropout — confirm in lanenet_merge_model).
    phase = tf.placeholder(dtype=tf.string, shape=None, name='net_phase')

    train_dataset = lanenet_data_processor.DataSet(train_dataset_file)
    val_dataset = lanenet_data_processor.DataSet(val_dataset_file)
    net = lanenet_merge_model.LaneNet()

    tower_grads = []
    global_step = tf.Variable(0, trainable=False)
    learning_rate = tf.train.polynomial_decay(CFG.TRAIN.LEARNING_RATE, global_step,
                                              CFG.TRAIN.EPOCHS, power=0.9)
    # Tie UPDATE_OPS (e.g. batch-norm moving averages) to the optimizer creation.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate, momentum=0.9)

    # Prefetch queues decouple CPU-side data loading from the GPU towers.
    img, label_instance, label_existence = train_dataset.next_batch(CFG.TRAIN.BATCH_SIZE)
    batch_queue = tf.contrib.slim.prefetch_queue.prefetch_queue(
        [img, label_instance, label_existence],
        capacity=2 * CFG.TRAIN.GPU_NUM,
        num_threads=CFG.TRAIN.CPU_NUM)
    val_img, val_label_instance, val_label_existence = val_dataset.next_batch(CFG.TRAIN.BATCH_SIZE)
    val_batch_queue = tf.contrib.slim.prefetch_queue.prefetch_queue(
        [val_img, val_label_instance, val_label_existence],
        capacity=2 * CFG.TRAIN.GPU_NUM,
        num_threads=CFG.TRAIN.CPU_NUM)

    # One tower per GPU; variables are shared through the common variable scope.
    # NOTE(review): the loss/accuracy tensors kept after the loop come from the
    # last tower only, and the validation forward pass is rebuilt in every tower
    # scope — looks intentional (last-tower monitoring), but worth confirming.
    with tf.variable_scope(tf.get_variable_scope()):
        for i in range(CFG.TRAIN.GPU_NUM):
            with tf.device('/gpu:%d' % i):
                with tf.name_scope('tower_%d' % i):
                    total_loss, instance_loss, existence_loss, accuracy, accuracy_back, _, out_logits_out, \
                        grad = forward(batch_queue, net, phase, optimizer)
                    tower_grads.append(grad)
                    val_op_total_loss, val_op_instance_loss, val_op_existence_loss, val_op_accuracy, \
                        val_op_accuracy_back, val_op_IoU, _, _ = forward(val_batch_queue, net, phase)

    grads = average_gradients(tower_grads)
    train_op = optimizer.apply_gradients(grads, global_step=global_step)

    # Running statistics used for the periodic console logging below.
    train_cost_time_mean = []
    train_instance_loss_mean = []
    train_existence_loss_mean = []
    train_accuracy_mean = []
    train_accuracy_back_mean = []

    saver = tf.train.Saver()
    model_save_dir = 'model/culane_lanenet/culane_scnn'
    if not ops.exists(model_save_dir):
        os.makedirs(model_save_dir)
    train_start_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
    model_name = 'culane_lanenet_{:s}_{:s}.ckpt'.format(net_flag, str(train_start_time))
    model_save_path = ops.join(model_save_dir, model_name)

    sess_config = tf.ConfigProto(device_count={'GPU': CFG.TRAIN.GPU_NUM},
                                 allow_soft_placement=True)
    sess_config.gpu_options.per_process_gpu_memory_fraction = CFG.TRAIN.GPU_MEMORY_FRACTION
    sess_config.gpu_options.allow_growth = CFG.TRAIN.TF_ALLOW_GROWTH
    sess_config.gpu_options.allocator_type = 'BFC'

    with tf.Session(config=sess_config) as sess:
        with sess.as_default():
            if weights_path is None:
                log.info('Training from scratch')
                init = tf.global_variables_initializer()
                sess.run(init)
            else:
                log.info('Restore model from last model checkpoint {:s}'.format(
                    weights_path))
                saver.restore(sess=sess, save_path=weights_path)

            # Load pretrained VGG-16 parameters (only when training from scratch).
            # Variable names are expected to look like ".../<layer>/<param>:0",
            # so the layer key is the third-from-last path component.
            if net_flag == 'vgg' and weights_path is None:
                pretrained_weights = np.load('./data/vgg16.npy', encoding='latin1').item()
                for vv in tf.trainable_variables():
                    weights = vv.name.split('/')
                    if len(weights) >= 3 and weights[-3] in pretrained_weights:
                        try:
                            weights_key = weights[-3]
                            weights = pretrained_weights[weights_key][0]
                            _op = tf.assign(vv, weights)
                            sess.run(_op)
                        except Exception as e:
                            # Shape mismatch etc. — best-effort load, skip this variable.
                            continue

            # Start the queue-runner threads that feed the prefetch queues.
            tf.train.start_queue_runners(sess=sess)

            for epoch in range(CFG.TRAIN.EPOCHS):
                # ---- training step ----
                t_start = time.time()
                _, c, train_accuracy, train_accuracy_back, train_instance_loss, train_existence_loss, binary_seg_img = \
                    sess.run([train_op, total_loss, accuracy, accuracy_back,
                              instance_loss, existence_loss, out_logits_out],
                             feed_dict={phase: 'train'})
                cost_time = time.time() - t_start
                train_cost_time_mean.append(cost_time)
                train_instance_loss_mean.append(train_instance_loss)
                train_existence_loss_mean.append(train_existence_loss)
                train_accuracy_mean.append(train_accuracy)
                train_accuracy_back_mean.append(train_accuracy_back)

                if epoch % CFG.TRAIN.DISPLAY_STEP == 0:
                    print(
                        'Epoch: {:d} loss_ins= {:6f} ({:6f}) loss_ext= {:6f} ({:6f}) accuracy= {:6f} ({:6f}) '
                        'accuracy_back= {:6f} ({:6f}) mean_time= {:5f}s '.format(
                            epoch + 1, train_instance_loss,
                            np.mean(train_instance_loss_mean), train_existence_loss,
                            np.mean(train_existence_loss_mean), train_accuracy,
                            np.mean(train_accuracy_mean), train_accuracy_back,
                            np.mean(train_accuracy_back_mean),
                            np.mean(train_cost_time_mean)))

                if epoch % 500 == 0:
                    # Reset running means so the logged averages stay recent.
                    train_cost_time_mean.clear()
                    train_instance_loss_mean.clear()
                    train_existence_loss_mean.clear()
                    train_accuracy_mean.clear()
                    train_accuracy_back_mean.clear()

                if epoch % 2000 == 0:
                    saver.save(sess=sess, save_path=model_save_path, global_step=epoch)

                # ---- validation: only every 10000 steps, never at step 0 ----
                if epoch % 10000 != 0 or epoch == 0:
                    continue
                val_cost_time_mean = []
                val_instance_loss_mean = []
                val_existence_loss_mean = []
                val_accuracy_mean = []
                val_accuracy_back_mean = []
                val_IoU_mean = []
                # Sweep the whole validation set once (per-GPU batching).
                for epoch_val in range(
                        int(len(val_dataset) / CFG.TRAIN.VAL_BATCH_SIZE / CFG.TRAIN.GPU_NUM)):
                    t_start_val = time.time()
                    c_val, val_accuracy, val_accuracy_back, val_IoU, val_instance_loss, val_existence_loss = \
                        sess.run([val_op_total_loss, val_op_accuracy, val_op_accuracy_back,
                                  val_op_IoU, val_op_instance_loss, val_op_existence_loss],
                                 feed_dict={phase: 'test'})
                    cost_time_val = time.time() - t_start_val
                    val_cost_time_mean.append(cost_time_val)
                    val_instance_loss_mean.append(val_instance_loss)
                    val_existence_loss_mean.append(val_existence_loss)
                    val_accuracy_mean.append(val_accuracy)
                    val_accuracy_back_mean.append(val_accuracy_back)
                    val_IoU_mean.append(val_IoU)
                    # `epoch_val % 1 == 0` is always true — logs every val batch.
                    if epoch_val % 1 == 0:
                        print(
                            'Epoch_Val: {:d} loss_ins= {:6f} ({:6f}) '
                            'loss_ext= {:6f} ({:6f}) accuracy= {:6f} ({:6f}) accuracy_back= {:6f} ({:6f}) '
                            'mIoU= {:6f} ({:6f}) mean_time= {:5f}s '.format(
                                epoch_val + 1, val_instance_loss,
                                np.mean(val_instance_loss_mean), val_existence_loss,
                                np.mean(val_existence_loss_mean), val_accuracy,
                                np.mean(val_accuracy_mean), val_accuracy_back,
                                np.mean(val_accuracy_back_mean), val_IoU,
                                np.mean(val_IoU_mean), np.mean(val_cost_time_mean)))
                val_cost_time_mean.clear()
                val_instance_loss_mean.clear()
                val_existence_loss_mean.clear()
                val_accuracy_mean.clear()
                val_accuracy_back_mean.clear()
                val_IoU_mean.clear()
def train_net(
        dataset_dir,
        weights_path=None,
        net_flag='vgg',
        save_dir="./logs/train/lanenet",
        tboard_save_path="./tboard/lanenet",
        ignore_labels_path="/media/remus/datasets/AVMSnapshots/AVM/ignore_labels.png",
        my_checkpoint="true"):
    """
    Train a single-GPU LaneNet model with an ignore-label mask and
    best-checkpoint tracking.

    :param dataset_dir: directory containing 'train.txt' and 'val.txt' index files
    :param weights_path: optional checkpoint to restore before training
    :param net_flag: choose which base network to use
    :param save_dir: directory for checkpoints and the exported graph def
    :param tboard_save_path: directory for TensorBoard summaries
    :param ignore_labels_path: image whose pixels mark regions to ignore in the loss
    :param my_checkpoint: string flag ("true"/other); when not "true", restore
        only variables actually present in the checkpoint — NOTE(review): a
        string-typed boolean; presumably driven by a CLI arg.
    :return: None
    """
    train_dataset_file = ops.join(dataset_dir, 'train.txt')
    val_dataset_file = ops.join(dataset_dir, 'val.txt')
    assert ops.exists(train_dataset_file)
    # tf.enable_eager_execution()
    train_dataset = lanenet_data_processor.DataSet(train_dataset_file)
    val_dataset = lanenet_data_processor.DataSet(val_dataset_file)

    # Build the whole graph on the second GPU.
    with tf.device('/gpu:1'):
        input_tensor = tf.placeholder(dtype=tf.float32,
                                      shape=[CFG.TRAIN.BATCH_SIZE,
                                             CFG.TRAIN.IMG_HEIGHT,
                                             CFG.TRAIN.IMG_WIDTH, 3],
                                      name='input_tensor')
        binary_label_tensor = tf.placeholder(dtype=tf.int64,
                                             shape=[CFG.TRAIN.BATCH_SIZE,
                                                    CFG.TRAIN.IMG_HEIGHT,
                                                    CFG.TRAIN.IMG_WIDTH, 1],
                                             name='binary_input_label')
        instance_label_tensor = tf.placeholder(dtype=tf.float32,
                                               shape=[CFG.TRAIN.BATCH_SIZE,
                                                      CFG.TRAIN.IMG_HEIGHT,
                                                      CFG.TRAIN.IMG_WIDTH],
                                               name='instance_input_label')
        phase = tf.placeholder(dtype=tf.string, shape=None, name='net_phase')

        net = lanenet_merge_model.LaneNet(net_flag=net_flag, phase=phase)

        # calculate the loss (pixels equal to ignore_label are excluded)
        compute_ret = net.compute_loss(input_tensor=input_tensor,
                                       binary_label=binary_label_tensor,
                                       instance_label=instance_label_tensor,
                                       ignore_label=255,
                                       name='lanenet_model')
        total_loss = compute_ret['total_loss']
        binary_seg_loss = compute_ret['binary_seg_loss']
        disc_loss = compute_ret['discriminative_loss']
        pix_embedding = compute_ret['instance_seg_logits']

        # calculate the accuracy: fraction of GT-positive pixels predicted
        # positive (i.e. recall of the lane class).
        out_logits = compute_ret['binary_seg_logits']
        out_logits = tf.nn.softmax(logits=out_logits)
        out_logits_out = tf.argmax(out_logits, axis=-1)
        out = tf.argmax(out_logits, axis=-1)
        out = tf.expand_dims(out, axis=-1)
        idx = tf.where(tf.equal(binary_label_tensor, 1))
        pix_cls_ret = tf.gather_nd(out, idx)
        accuracy = tf.count_nonzero(pix_cls_ret)
        accuracy = tf.divide(accuracy,
                             tf.cast(tf.shape(pix_cls_ret)[0], tf.int64))

        global_step = tf.Variable(0, trainable=False)
        learning_rate = tf.train.exponential_decay(CFG.TRAIN.LEARNING_RATE,
                                                   global_step,
                                                   CFG.TRAIN.LR_DECAY_STEPS,
                                                   CFG.TRAIN.LR_DECAY_RATE,
                                                   staircase=True)
        # Run UPDATE_OPS (e.g. batch-norm moving averages) with each train step.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            optimizer = tf.train.MomentumOptimizer(
                learning_rate=learning_rate,
                momentum=0.9).minimize(loss=total_loss,
                                       var_list=tf.trainable_variables(),
                                       global_step=global_step)
        # update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

    # Set tf saver. When restoring a foreign checkpoint, only restore the
    # variables actually present in it.
    if my_checkpoint == "true":
        init_saver = tf.train.Saver()
    else:
        from correct_path_saver import restore_from_classification_checkpoint_fn, get_variables_available_in_checkpoint
        if weights_path is not None:
            # var_map = restore_from_classification_checkpoint_fn("lanenet_model/inference")
            available_var_map = (get_variables_available_in_checkpoint(
                tf.global_variables(), weights_path, include_global_step=False))
            init_saver = tf.train.Saver(available_var_map)
        else:
            init_saver = tf.train.Saver()

    if not ops.exists(save_dir):
        os.makedirs(save_dir)
    train_start_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
    model_name = '{:s}_lanenet_{:s}.ckpt'.format(net_flag, str(train_start_time))
    model_save_path = ops.join(save_dir, model_name)

    # Set tf summary
    if not ops.exists(tboard_save_path):
        os.makedirs(tboard_save_path)
    train_cost_scalar = tf.summary.scalar(name='train_cost', tensor=total_loss)
    val_cost_scalar = tf.summary.scalar(name='val_cost', tensor=total_loss)
    train_accuracy_scalar = tf.summary.scalar(name='train_accuracy', tensor=accuracy)
    val_accuracy_scalar = tf.summary.scalar(name='val_accuracy', tensor=accuracy)
    train_binary_seg_loss_scalar = tf.summary.scalar(
        name='train_binary_seg_loss', tensor=binary_seg_loss)
    val_binary_seg_loss_scalar = tf.summary.scalar(name='val_binary_seg_loss',
                                                   tensor=binary_seg_loss)
    train_instance_seg_loss_scalar = tf.summary.scalar(
        name='train_instance_seg_loss', tensor=disc_loss)
    val_instance_seg_loss_scalar = tf.summary.scalar(
        name='val_instance_seg_loss', tensor=disc_loss)
    learning_rate_scalar = tf.summary.scalar(name='learning_rate', tensor=learning_rate)
    train_merge_summary_op = tf.summary.merge([
        train_accuracy_scalar, train_cost_scalar, learning_rate_scalar,
        train_binary_seg_loss_scalar, train_instance_seg_loss_scalar
    ])
    val_merge_summary_op = tf.summary.merge([
        val_accuracy_scalar, val_cost_scalar, val_binary_seg_loss_scalar,
        val_instance_seg_loss_scalar
    ])

    # Set sess configuration
    sess_config = tf.ConfigProto(allow_soft_placement=True)
    sess_config.gpu_options.per_process_gpu_memory_fraction = CFG.TRAIN.GPU_MEMORY_FRACTION
    sess_config.gpu_options.allow_growth = CFG.TRAIN.TF_ALLOW_GROWTH
    sess_config.gpu_options.allocator_type = 'BFC'
    # sess_config.device_count = {'GPU': 0}
    sess = tf.Session(config=sess_config)
    # sess = tf_debug.TensorBoardDebugWrapperSession(sess=sess,
    #                                                grpc_debug_server_addresses="remusm-pc:7000",
    #                                                send_traceback_and_source_code=False)

    summary_writer = tf.summary.FileWriter(tboard_save_path)
    summary_writer.add_graph(sess.graph)

    # Set the training parameters
    train_epochs = CFG.TRAIN.EPOCHS

    tf.logging.info('Global configuration is as follows:')
    tf.logging.info(CFG)

    iter_saver = tf.train.Saver(max_to_keep=10)   # periodic checkpoints
    best_saver = tf.train.Saver(max_to_keep=3)    # lowest-loss checkpoints

    with sess.as_default():
        sess.run(tf.global_variables_initializer())
        tf.train.write_graph(graph_or_graph_def=sess.graph,
                             logdir='',
                             name='{:s}/lanenet_model.pb'.format(save_dir))
        if weights_path is None:
            tf.logging.info('Training from scratch')
            init = tf.global_variables_initializer()
            sess.run(init)
        else:
            tf.logging.info('Restore model from last model checkpoint {:s}'.format(
                weights_path))
            init_saver.restore(sess=sess, save_path=weights_path)
            # Restart step counting (and thus the LR schedule) after a restore.
            assign_op = global_step.assign(0)
            sess.run(assign_op)

        # Load pretrained VGG-16 parameters (only when training from scratch).
        if net_flag == 'vgg' and weights_path is None:
            pretrained_weights = np.load('./data/vgg16.npy', encoding='latin1').item()
            for vv in tf.trainable_variables():
                weights_key = vv.name.split('/')[-3]
                try:
                    weights = pretrained_weights[weights_key][0]
                    _op = tf.assign(vv, weights)
                    sess.run(_op)
                except Exception as e:
                    # Layer not in the pretrained dict / shape mismatch — skip.
                    continue

        train_cost_time_mean = []
        val_cost_time_mean = []
        ignore_label_mask = cv2.imread(ignore_labels_path)
        last_c = 100000  # best (lowest) training loss seen so far

        for epoch in range(train_epochs):
            # training part
            t_start = time.time()
            with tf.device('/cpu:0'):
                gt_imgs, binary_gt_labels, instance_gt_labels = train_dataset.next_batch(
                    CFG.TRAIN.BATCH_SIZE,
                    ignore_label_mask=ignore_label_mask,
                    ignore_label=255)
                # gt_imgs = [tmp - VGG_MEAN for tmp in gt_imgs]
                # Normalize to roughly [-1, 1].
                gt_imgs = [tmp / 128.0 - 1.0 for tmp in gt_imgs]
                binary_gt_labels = [
                    np.expand_dims(tmp, axis=-1) for tmp in binary_gt_labels
                ]
            phase_train = 'train'

            _, c, train_accuracy, train_summary, binary_loss, instance_loss, embedding, binary_seg_img, g_step = \
                sess.run([optimizer, total_loss, accuracy,
                          train_merge_summary_op, binary_seg_loss,
                          disc_loss, pix_embedding,
                          out_logits_out, global_step],
                         feed_dict={input_tensor: gt_imgs,
                                    binary_label_tensor: binary_gt_labels,
                                    instance_label_tensor: instance_gt_labels,
                                    phase: phase_train})

            # if epoch % 10 == 0:
            #     tf.logging.info("Epoch {}."
            #                     "Total loss: {}. Train acc: {}."
            #                     " Binary loss: {}. Instance loss: {}".format(epoch, c, train_accuracy,
            #                                                                  binary_loss, instance_loss))

            # Abort (dumping the offending batch to disk) if the loss diverged.
            if math.isnan(c) or math.isnan(binary_loss) or math.isnan(
                    instance_loss):
                tf.logging.error('cost is: {:.5f}'.format(c))
                tf.logging.error('binary cost is: {:.5f}'.format(binary_loss))
                tf.logging.error(
                    'instance cost is: {:.5f}'.format(instance_loss))
                # cv2.imwrite('nan_image.png', gt_imgs[0] + VGG_MEAN)
                # Invert the [-1, 1] normalization for visualization.
                cv2.imwrite('nan_image.png', (gt_imgs[0] + 1.0) * 128)
                cv2.imwrite('nan_instance_label.png', instance_gt_labels[0])
                cv2.imwrite('nan_binary_label.png', binary_gt_labels[0] * 255)
                return

            # Periodically dump a visual snapshot of inputs and predictions.
            if epoch % 100 == 0:
                # cv2.imwrite('nan_image.png', gt_imgs[0] + VGG_MEAN)
                cv2.imwrite('image.png', (gt_imgs[0] + 1.0) * 128)
                cv2.imwrite('binary_label.png', binary_gt_labels[0] * 255)
                cv2.imwrite('instance_label.png', instance_gt_labels[0])
                cv2.imwrite('binary_seg_img.png', binary_seg_img[0] * 255)
                # Rescale each embedding channel to [0, 255] for viewing;
                # assumes the embedding has (at least) 4 channels.
                for i in range(4):
                    embedding[0][:, :, i] = minmax_scale(embedding[0][:, :, i])
                embedding_image = np.array(embedding[0], np.uint8)
                cv2.imwrite('embedding.png', embedding_image[:, :, :-1])

            cost_time = time.time() - t_start
            train_cost_time_mean.append(cost_time)
            summary_writer.add_summary(summary=train_summary, global_step=epoch)

            # validation part
            with tf.device('/cpu:0'):
                gt_imgs_val, binary_gt_labels_val, instance_gt_labels_val \
                    = val_dataset.next_batch(CFG.TRAIN.VAL_BATCH_SIZE,
                                             ignore_label_mask=ignore_label_mask)
                # gt_imgs_val = [tmp - VGG_MEAN for tmp in gt_imgs_val]
                gt_imgs_val = [tmp / 128.0 - 1.0 for tmp in gt_imgs_val]
                binary_gt_labels_val = [
                    np.expand_dims(tmp, axis=-1) for tmp in binary_gt_labels_val
                ]
            phase_val = 'test'

            t_start_val = time.time()
            c_val, val_summary, val_accuracy, val_binary_seg_loss, val_instance_seg_loss = \
                sess.run([total_loss, val_merge_summary_op, accuracy,
                          binary_seg_loss, disc_loss],
                         feed_dict={input_tensor: gt_imgs_val,
                                    binary_label_tensor: binary_gt_labels_val,
                                    instance_label_tensor: instance_gt_labels_val,
                                    phase: phase_val})

            if epoch % 100 == 0:
                # cv2.imwrite('test_image.png', gt_imgs_val[0] + VGG_MEAN)
                cv2.imwrite('test_image.png', (gt_imgs_val[0] + 1.0) * 128)

            summary_writer.add_summary(val_summary, global_step=epoch)

            cost_time_val = time.time() - t_start_val
            val_cost_time_mean.append(cost_time_val)

            if epoch % CFG.TRAIN.DISPLAY_STEP == 0:
                tf.logging.info(
                    'Step: {:d} total_loss= {:6f} binary_seg_loss= {:6f} instance_seg_loss= {:6f} accuracy= {:6f}'
                    ' mean_cost_time= {:5f}s '.format(
                        epoch + 1, c, binary_loss, instance_loss,
                        train_accuracy, np.mean(train_cost_time_mean)))
                train_cost_time_mean.clear()

            if epoch % CFG.TRAIN.TEST_DISPLAY_STEP == 0:
                tf.logging.info(
                    'Step_Val: {:d} total_loss= {:6f} binary_seg_loss= {:6f} '
                    'instance_seg_loss= {:6f} accuracy= {:6f} '
                    'mean_cost_time= {:5f}s '.format(
                        epoch + 1, c_val, val_binary_seg_loss,
                        val_instance_seg_loss, val_accuracy,
                        np.mean(val_cost_time_mean)))
                val_cost_time_mean.clear()

            if epoch % 2000 == 0:
                iter_saver.save(sess=sess,
                                save_path=model_save_path,
                                global_step=epoch)

            # Track the best (lowest) training loss in a separate saver.
            if c < last_c:
                last_c = c
                save_dir_best = save_dir + "/best"
                if not ops.exists(save_dir_best):
                    os.makedirs(save_dir_best)
                best_model_save_path = ops.join(save_dir_best, model_name)
                best_saver.save(sess=sess,
                                save_path=best_model_save_path,
                                global_step=epoch)
    sess.close()
    return
def train_net(dataset_dir, weights_path=None, net_flag='shuffle'):
    """
    Train a LaneNet model with a VGG or ShuffleNet backbone on the DVS dataset.

    Fixes over the previous revision:
      * the per-backbone preprocessing mean (``MEAN``) was selected but never
        used — every image had ``VGG_MEAN`` subtracted even for the ShuffleNet
        backbone; preprocessing and the debug-image un-normalization now use
        ``MEAN`` consistently;
      * validation binary labels were resized with the default bilinear
        interpolation, producing fractional label values; they now use
        ``cv2.INTER_NEAREST`` like the training path.

    :param dataset_dir: directory containing 'train.txt' and 'val.txt' index files
    :param weights_path: optional checkpoint to restore before training
    :param net_flag: choose which base network to use ('vgg' or 'shuffle')
    :return: None
    """
    train_dataset_file = ops.join(dataset_dir, 'train.txt')
    val_dataset_file = ops.join(dataset_dir, 'val.txt')
    assert ops.exists(train_dataset_file)

    train_dataset = lanenet_data_processor.DataSet(train_dataset_file)
    val_dataset = lanenet_data_processor.DataSet(val_dataset_file)

    with tf.device('/gpu:0'):
        print("gpu enableing...")
        # NOTE(review): for grayscale training images the channel dimension
        # should presumably be 1 instead of 3 — confirm against the dataset.
        input_tensor = tf.placeholder(dtype=tf.float32,
                                      shape=[CFG.TRAIN.BATCH_SIZE,
                                             CFG.TRAIN.IMG_HEIGHT,
                                             CFG.TRAIN.IMG_WIDTH, 3],
                                      name='input_tensor')
        binary_label_tensor = tf.placeholder(dtype=tf.int64,
                                             shape=[CFG.TRAIN.BATCH_SIZE,
                                                    CFG.TRAIN.IMG_HEIGHT,
                                                    CFG.TRAIN.IMG_WIDTH, 1],
                                             name='binary_input_label')
        instance_label_tensor = tf.placeholder(dtype=tf.float32,
                                               shape=[CFG.TRAIN.BATCH_SIZE,
                                                      CFG.TRAIN.IMG_HEIGHT,
                                                      CFG.TRAIN.IMG_WIDTH],
                                               name='instance_input_label')
        phase = tf.placeholder(dtype=tf.string, shape=None, name='net_phase')

        net = shufflenet_merge_model.Shuffle_LaneNet(net_flag=net_flag, phase=phase)

        # calculate the loss
        compute_ret = net.compute_loss(input_tensor=input_tensor,
                                       binary_label=binary_label_tensor,
                                       instance_label=instance_label_tensor,
                                       name='lanenet_model')
        total_loss = compute_ret['total_loss']
        binary_seg_loss = compute_ret['binary_seg_loss']
        disc_loss = compute_ret['discriminative_loss']
        pix_embedding = compute_ret['instance_seg_logits']

        # calculate the accuracy: fraction of GT-positive pixels predicted
        # positive (i.e. recall of the lane class).
        out_logits = compute_ret['binary_seg_logits']
        out_logits = tf.nn.softmax(logits=out_logits)
        out_logits_out = tf.argmax(out_logits, axis=-1)
        out = tf.argmax(out_logits, axis=-1)
        out = tf.expand_dims(out, axis=-1)
        idx = tf.where(tf.equal(binary_label_tensor, 1))
        pix_cls_ret = tf.gather_nd(out, idx)
        accuracy = tf.count_nonzero(pix_cls_ret)
        accuracy = tf.divide(accuracy,
                             tf.cast(tf.shape(pix_cls_ret)[0], tf.int64))

        global_step = tf.Variable(0, trainable=False)
        learning_rate = tf.train.exponential_decay(CFG.TRAIN.LEARNING_RATE,
                                                   global_step,
                                                   100000, 0.1,
                                                   staircase=True)
        # Run UPDATE_OPS (e.g. batch-norm moving averages) with each train step.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            optimizer = tf.train.MomentumOptimizer(
                learning_rate=learning_rate,
                momentum=0.9).minimize(loss=total_loss,
                                       var_list=tf.trainable_variables(),
                                       global_step=global_step)

    # Set tf saver
    saver = tf.train.Saver()

    # Choose the checkpoint directory per backbone.
    if net_flag == 'vgg':
        model_save_dir = 'model/vgg/dvs'
    if net_flag == 'shuffle':
        model_save_dir = 'model/shufflenet/dvs'
    if not ops.exists(model_save_dir):
        os.makedirs(model_save_dir)
    train_start_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
    model_name = 'dvs_lanenet_{:s}_{:s}.ckpt'.format(net_flag, str(train_start_time))
    model_save_path = ops.join(model_save_dir, model_name)

    # Set tf summary
    tboard_save_path = 'tboard/tusimple_lanenet/{:s}'.format(net_flag)
    if not ops.exists(tboard_save_path):
        os.makedirs(tboard_save_path)
    train_cost_scalar = tf.summary.scalar(name='train_cost', tensor=total_loss)
    val_cost_scalar = tf.summary.scalar(name='val_cost', tensor=total_loss)
    train_accuracy_scalar = tf.summary.scalar(name='train_accuracy', tensor=accuracy)
    val_accuracy_scalar = tf.summary.scalar(name='val_accuracy', tensor=accuracy)
    train_binary_seg_loss_scalar = tf.summary.scalar(
        name='train_binary_seg_loss', tensor=binary_seg_loss)
    val_binary_seg_loss_scalar = tf.summary.scalar(name='val_binary_seg_loss',
                                                   tensor=binary_seg_loss)
    train_instance_seg_loss_scalar = tf.summary.scalar(
        name='train_instance_seg_loss', tensor=disc_loss)
    val_instance_seg_loss_scalar = tf.summary.scalar(
        name='val_instance_seg_loss', tensor=disc_loss)
    learning_rate_scalar = tf.summary.scalar(name='learning_rate', tensor=learning_rate)
    train_merge_summary_op = tf.summary.merge([
        train_accuracy_scalar, train_cost_scalar, learning_rate_scalar,
        train_binary_seg_loss_scalar, train_instance_seg_loss_scalar
    ])
    val_merge_summary_op = tf.summary.merge([
        val_accuracy_scalar, val_cost_scalar, val_binary_seg_loss_scalar,
        val_instance_seg_loss_scalar
    ])

    # Set sess configuration
    sess_config = tf.ConfigProto(allow_soft_placement=True)
    sess_config.gpu_options.per_process_gpu_memory_fraction = CFG.TRAIN.GPU_MEMORY_FRACTION
    sess_config.gpu_options.allow_growth = CFG.TRAIN.TF_ALLOW_GROWTH
    sess_config.gpu_options.allocator_type = 'BFC'

    sess = tf.Session(config=sess_config)

    summary_writer = tf.summary.FileWriter(tboard_save_path)
    summary_writer.add_graph(sess.graph)

    # Set the training parameters
    train_epochs = CFG.TRAIN.EPOCHS

    log.info('Global configuration is as follows:')
    log.info(CFG)

    with sess.as_default():
        tf.train.write_graph(
            graph_or_graph_def=sess.graph,
            logdir='',
            name='{:s}/lanenet_model.pb'.format(model_save_dir))

        if weights_path is None:
            log.info('Training from scratch')
            init = tf.global_variables_initializer()
            sess.run(init)
        else:
            log.info('Restore model from last model checkpoint {:s}'.format(
                weights_path))
            saver.restore(sess=sess, save_path=weights_path)

        # Load pretrained VGG-16 parameters (only when training from scratch).
        if net_flag == 'vgg' and weights_path is None:
            print('test')
            pretrained_weights = np.load('./data/vgg16.npy', encoding='latin1').item()
            for vv in tf.trainable_variables():
                weights_key = vv.name.split('/')[-3]
                try:
                    weights = pretrained_weights[weights_key][0]
                    _op = tf.assign(vv, weights)
                    sess.run(_op)
                except Exception as e:
                    # Layer not in the pretrained dict / shape mismatch — skip.
                    continue

        # Load ImageNet-pretrained ShuffleNet weights.
        if net_flag == 'shuffle' and weights_path is None:
            variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
            # NOTE(review): catching only KeyboardInterrupt here is suspicious —
            # a missing weights file raises IOError/OSError, which would escape.
            try:
                print("Loading ImageNet pretrained weights...")
                pretrained_dict = load_obj('./data/shufflenet_weights.pkl')
                run_list = []
                for variable in variables:
                    for key, value in pretrained_dict.items():
                        # Adding ':' means that we are interested in the variable itself and not the variable parameters
                        # that are used in adaptive optimizers
                        if key + ":" in variable.name:
                            run_list.append(tf.assign(variable, value))
                sess.run(run_list)
                print("ImageNet Pretrained Weights Loaded Initially\n\n")
            except KeyboardInterrupt:
                print("No pretrained ImageNet weights exist. Skipping...\n\n")

        # Select the preprocessing mean matching the backbone. Falls back to
        # VGG_MEAN for unknown flags (the pre-fix behavior for all flags).
        if net_flag == 'vgg':
            MEAN = VGG_MEAN
        elif net_flag == 'shuffle':
            MEAN = SHUFFLE_MEAN
        else:
            MEAN = VGG_MEAN

        train_cost_time_mean = []
        val_cost_time_mean = []

        for epoch in range(train_epochs):
            # training part
            t_start = time.time()
            with tf.device('/cpu:0'):
                gt_imgs, binary_gt_labels, instance_gt_labels = train_dataset.next_batch(
                    CFG.TRAIN.BATCH_SIZE)
                gt_imgs = [
                    cv2.resize(tmp,
                               dsize=(CFG.TRAIN.IMG_WIDTH, CFG.TRAIN.IMG_HEIGHT),
                               dst=tmp,
                               interpolation=cv2.INTER_LINEAR) for tmp in gt_imgs
                ]
                # Subtract the backbone-specific mean (was hard-coded VGG_MEAN).
                gt_imgs = [tmp - MEAN for tmp in gt_imgs]
                # Labels must use nearest-neighbour to keep discrete values.
                binary_gt_labels = [
                    cv2.resize(tmp,
                               dsize=(CFG.TRAIN.IMG_WIDTH, CFG.TRAIN.IMG_HEIGHT),
                               dst=tmp,
                               interpolation=cv2.INTER_NEAREST)
                    for tmp in binary_gt_labels
                ]
                binary_gt_labels = [
                    np.expand_dims(tmp, axis=-1) for tmp in binary_gt_labels
                ]
                instance_gt_labels = [
                    cv2.resize(tmp,
                               dsize=(CFG.TRAIN.IMG_WIDTH, CFG.TRAIN.IMG_HEIGHT),
                               dst=tmp,
                               interpolation=cv2.INTER_NEAREST)
                    for tmp in instance_gt_labels
                ]
            phase_train = 'train'

            _, c, train_accuracy, train_summary, binary_loss, instance_loss, embedding, binary_seg_img = \
                sess.run([optimizer, total_loss, accuracy,
                          train_merge_summary_op, binary_seg_loss,
                          disc_loss, pix_embedding,
                          out_logits_out],
                         feed_dict={input_tensor: gt_imgs,
                                    binary_label_tensor: binary_gt_labels,
                                    instance_label_tensor: instance_gt_labels,
                                    phase: phase_train})

            # Abort (dumping the offending batch to disk) if the loss diverged.
            if math.isnan(c) or math.isnan(binary_loss) or math.isnan(
                    instance_loss):
                log.error('cost is: {:.5f}'.format(c))
                log.error('binary cost is: {:.5f}'.format(binary_loss))
                log.error('instance cost is: {:.5f}'.format(instance_loss))
                cv2.imwrite('nan_image.png', gt_imgs[0] + MEAN)
                cv2.imwrite('nan_instance_label.png', instance_gt_labels[0])
                cv2.imwrite('nan_binary_label.png', binary_gt_labels[0] * 255)
                return

            # Periodically dump a visual snapshot of inputs and predictions.
            if epoch % 100 == 0:
                cv2.imwrite('image.png', gt_imgs[0] + MEAN)
                cv2.imwrite('binary_label.png', binary_gt_labels[0] * 255)
                cv2.imwrite('instance_label.png', instance_gt_labels[0])
                cv2.imwrite('binary_seg_img.png', binary_seg_img[0] * 255)
                # Rescale each embedding channel to [0, 255] for viewing;
                # assumes the embedding has (at least) 4 channels.
                for i in range(4):
                    embedding[0][:, :, i] = minmax_scale(embedding[0][:, :, i])
                embedding_image = np.array(embedding[0], np.uint8)
                cv2.imwrite('embedding.png', embedding_image)

            cost_time = time.time() - t_start
            train_cost_time_mean.append(cost_time)
            summary_writer.add_summary(summary=train_summary, global_step=epoch)

            # validation part
            with tf.device('/cpu:0'):
                gt_imgs_val, binary_gt_labels_val, instance_gt_labels_val \
                    = val_dataset.next_batch(CFG.TRAIN.VAL_BATCH_SIZE)
                gt_imgs_val = [
                    cv2.resize(tmp,
                               dsize=(CFG.TRAIN.IMG_WIDTH, CFG.TRAIN.IMG_HEIGHT),
                               dst=tmp,
                               interpolation=cv2.INTER_LINEAR)
                    for tmp in gt_imgs_val
                ]
                gt_imgs_val = [tmp - MEAN for tmp in gt_imgs_val]
                # Use nearest-neighbour here too (was defaulting to bilinear,
                # which yields fractional binary-label values).
                binary_gt_labels_val = [
                    cv2.resize(tmp,
                               dsize=(CFG.TRAIN.IMG_WIDTH, CFG.TRAIN.IMG_HEIGHT),
                               dst=tmp,
                               interpolation=cv2.INTER_NEAREST)
                    for tmp in binary_gt_labels_val
                ]
                binary_gt_labels_val = [
                    np.expand_dims(tmp, axis=-1) for tmp in binary_gt_labels_val
                ]
                instance_gt_labels_val = [
                    cv2.resize(tmp,
                               dsize=(CFG.TRAIN.IMG_WIDTH, CFG.TRAIN.IMG_HEIGHT),
                               dst=tmp,
                               interpolation=cv2.INTER_NEAREST)
                    for tmp in instance_gt_labels_val
                ]
            phase_val = 'test'

            t_start_val = time.time()
            c_val, val_summary, val_accuracy, val_binary_seg_loss, val_instance_seg_loss = \
                sess.run([total_loss, val_merge_summary_op, accuracy,
                          binary_seg_loss, disc_loss],
                         feed_dict={input_tensor: gt_imgs_val,
                                    binary_label_tensor: binary_gt_labels_val,
                                    instance_label_tensor: instance_gt_labels_val,
                                    phase: phase_val})

            if epoch % 100 == 0:
                cv2.imwrite('test_image.png', gt_imgs_val[0] + MEAN)

            summary_writer.add_summary(val_summary, global_step=epoch)

            cost_time_val = time.time() - t_start_val
            val_cost_time_mean.append(cost_time_val)

            if epoch % CFG.TRAIN.DISPLAY_STEP == 0:
                log.info(
                    'Epoch: {:d} total_loss= {:6f} binary_seg_loss= {:6f} instance_seg_loss= {:6f} accuracy= {:6f}'
                    ' mean_cost_time= {:5f}s '.format(
                        epoch + 1, c, binary_loss, instance_loss,
                        train_accuracy, np.mean(train_cost_time_mean)))
                train_cost_time_mean.clear()

            if epoch % CFG.TRAIN.TEST_DISPLAY_STEP == 0:
                log.info(
                    'Epoch_Val: {:d} total_loss= {:6f} binary_seg_loss= {:6f} '
                    'instance_seg_loss= {:6f} accuracy= {:6f} '
                    'mean_cost_time= {:5f}s '.format(
                        epoch + 1, c_val, val_binary_seg_loss,
                        val_instance_seg_loss, val_accuracy,
                        np.mean(val_cost_time_mean)))
                val_cost_time_mean.clear()

            if epoch % 10000 == 0:
                saver.save(sess=sess, save_path=model_save_path, global_step=epoch)
    sess.close()
    return
def train_net(dataset_dir, weights_path=None, net_flag='vgg', initial_step=0):
    """Train LaneNet end-to-end (binary + instance segmentation heads).

    Builds the TF1 graph (placeholders, loss, recall-style accuracy, Momentum
    optimizer with exponential LR decay), then runs a step-based train loop that
    interleaves one validation batch per step, periodically dumps debug images,
    writes TensorBoard summaries, and checkpoints every 2000 steps.

    NOTE(review): this file defines `train_net` more than once; only the last
    definition is bound at import time — confirm which one is intended.

    :param dataset_dir: directory containing '7-3_random_train.txt' / '7-3_random_val.txt'
    :param weights_path: optional checkpoint to restore; when None, trains from
        scratch (and, for 'vgg', loads ./data/vgg16.npy pretrained weights)
    :param net_flag: choose which base network to use ('vgg' or 'mobile')
    :param initial_step: step index to resume the loop counter from
    :return: None (side effects: checkpoints, summaries, image dumps)
    """
    train_dataset_file = ops.join(dataset_dir, '7-3_random_train.txt')
    val_dataset_file = ops.join(dataset_dir, '7-3_random_val.txt')
    # train_dataset_file = ops.join(dataset_dir, '9-1_train.txt')
    # val_dataset_file = ops.join(dataset_dir, '9-1_val.txt')

    assert ops.exists(train_dataset_file)

    train_dataset = lanenet_data_processor.DataSet(train_dataset_file)
    val_dataset = lanenet_data_processor.DataSet(val_dataset_file)

    with tf.device('/gpu:0'):
        # with tf.device('/cpu:0'):
        # Fixed-size batch placeholders; label tensors share the image H/W.
        input_tensor = tf.placeholder(dtype=tf.float32,
                                      shape=[CFG.TRAIN.BATCH_SIZE,
                                             CFG.TRAIN.IMG_HEIGHT,
                                             CFG.TRAIN.IMG_WIDTH, 3],
                                      name='input_tensor')
        binary_label_tensor = tf.placeholder(dtype=tf.int64,
                                             shape=[CFG.TRAIN.BATCH_SIZE,
                                                    CFG.TRAIN.IMG_HEIGHT,
                                                    CFG.TRAIN.IMG_WIDTH, 1],
                                             name='binary_input_label')
        instance_label_tensor = tf.placeholder(dtype=tf.float32,
                                               shape=[CFG.TRAIN.BATCH_SIZE,
                                                      CFG.TRAIN.IMG_HEIGHT,
                                                      CFG.TRAIN.IMG_WIDTH],
                                               name='instance_input_label')
        # binary_seg_img_tensor = tf.placeholder(dtype=tf.uint8,
        #                                        shape=[CFG.TRAIN.IMG_HEIGHT,
        #                                               CFG.TRAIN.IMG_WIDTH, 1])
        # Phase is fed as a string tensor ('train'/'test') at run time.
        phase = tf.placeholder(dtype=tf.string, shape=None, name='net_phase')

        net = lanenet_merge_model.LaneNet(net_flag=net_flag, phase=phase)

        # calculate the loss
        compute_ret = net.compute_loss(input_tensor=input_tensor,
                                       binary_label=binary_label_tensor,
                                       instance_label=instance_label_tensor,
                                       name='lanenet_model')
        total_loss = compute_ret['total_loss']
        binary_seg_loss = compute_ret['binary_seg_loss']
        disc_loss = compute_ret['discriminative_loss']
        pix_embedding = compute_ret['instance_seg_logits']
        counts = compute_ret['counts']

        # calculate the accuracy
        out_logits = compute_ret['binary_seg_logits']
        out_logits = tf.nn.softmax(logits=out_logits)
        out_logits_out = tf.argmax(
            out_logits,
            axis=-1)  # transform a 2-channel feature map into a binary image
        out = tf.argmax(out_logits, axis=-1)
        out = tf.expand_dims(out, axis=-1)

        idx = tf.where(tf.equal(binary_label_tensor,
                                1))  # select the Positive Pixels in GT image
        pix_cls_ret = tf.gather_nd(
            out, idx)  # slice out the corresponding pixels in output image
        accuracy = tf.count_nonzero(pix_cls_ret)  # True Positive
        accuracy = tf.divide(accuracy,
                             tf.cast(tf.shape(pix_cls_ret)[0], tf.int64))
        # Accuracy = TP / (TP + FN), ie. Recall

        global_step = tf.Variable(0, trainable=False)
        learning_rate = tf.train.exponential_decay(CFG.TRAIN.LEARNING_RATE,
                                                   global_step,
                                                   CFG.TRAIN.LR_DECAY_STEPS,
                                                   CFG.TRAIN.LR_DECAY_RATE,
                                                   staircase=True)
        # Run UPDATE_OPS (e.g. batch-norm moving stats) before each train step.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            optimizer = tf.train.MomentumOptimizer(
                learning_rate=learning_rate,
                momentum=0.9).minimize(loss=total_loss,
                                       var_list=tf.trainable_variables(),
                                       global_step=global_step)
            # optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)

    # Set tf saver
    saver = tf.train.Saver()
    model_save_dir = 'model/tusimple_lanenet'
    if not ops.exists(model_save_dir):
        os.makedirs(model_save_dir)
    train_start_time = time.strftime('%Y-%m-%d-%H-%M-%S',
                                     time.localtime(time.time()))
    model_name = 'tusimple_lanenet_{:s}_{:s}.ckpt'.format(
        net_flag, str(train_start_time))
    model_save_path = ops.join(model_save_dir, model_name)
    img_output_dir = f'output/{net_flag}_{train_start_time}'

    # Set tf restorer
    # NOTE(review): the checkpoint reader below runs unconditionally, even when
    # net_flag == 'vgg' — this fails if the MobileNet checkpoint file is absent.
    mobile_pretrained_path = 'model/mobilenet/mobilenet_v2_1.0_224.ckpt'
    reader = tf.train.NewCheckpointReader(mobile_pretrained_path)
    restore_dict = dict()
    for v in tf.trainable_variables():
        s = v.name.split(':')[0]
        i = s.find('MobilenetV2')
        if i != -1:
            # Strip the outer scope so names match the pretrained checkpoint.
            tensor_name = s[i:]
            # print(tensor_name)
            if reader.has_tensor(tensor_name):
                # print('has tensor ', tensor_name)
                restore_dict[tensor_name] = v
    pretrained_saver = tf.train.Saver(restore_dict, name="pretrained_saver")

    # Set tf summary
    tboard_save_path = f'tboard/tusimple_lanenet/{net_flag}/{train_start_time}'
    if not ops.exists(tboard_save_path):
        os.makedirs(tboard_save_path)
    train_cost_scalar = tf.summary.scalar(name='train_cost', tensor=total_loss)
    val_cost_scalar = tf.summary.scalar(name='val_cost', tensor=total_loss)
    train_accuracy_scalar = tf.summary.scalar(name='train_accuracy',
                                              tensor=accuracy)
    val_accuracy_scalar = tf.summary.scalar(name='val_accuracy',
                                            tensor=accuracy)
    train_binary_seg_loss_scalar = tf.summary.scalar(
        name='train_binary_seg_loss', tensor=binary_seg_loss)
    val_binary_seg_loss_scalar = tf.summary.scalar(name='val_binary_seg_loss',
                                                   tensor=binary_seg_loss)
    train_instance_seg_loss_scalar = tf.summary.scalar(
        name='train_instance_seg_loss', tensor=disc_loss)
    val_instance_seg_loss_scalar = tf.summary.scalar(
        name='val_instance_seg_loss', tensor=disc_loss)
    learning_rate_scalar = tf.summary.scalar(name='learning_rate',
                                             tensor=learning_rate)
    train_merge_summary_op = tf.summary.merge([
        train_accuracy_scalar, train_cost_scalar, learning_rate_scalar,
        train_binary_seg_loss_scalar, train_instance_seg_loss_scalar
    ])  # , train_bin_seg_img
    val_merge_summary_op = tf.summary.merge([
        val_accuracy_scalar, val_cost_scalar, val_binary_seg_loss_scalar,
        val_instance_seg_loss_scalar
    ])

    # Set sess configuration
    sess_config = tf.ConfigProto(allow_soft_placement=True)
    sess_config.gpu_options.per_process_gpu_memory_fraction = CFG.TRAIN.GPU_MEMORY_FRACTION
    sess_config.gpu_options.allow_growth = CFG.TRAIN.TF_ALLOW_GROWTH
    sess_config.gpu_options.allocator_type = 'BFC'
    sess = tf.Session(config=sess_config)
    # summary_writer = tf.summary.FileWriter(tboard_save_path)
    # summary_writer.add_graph(sess.graph)
    summary_writer = tf.summary.FileWriter(tboard_save_path, sess.graph)

    # Set the training parameters
    train_steps = CFG.TRAIN.STEPS

    log.info('Global configuration is as follows:')
    log.info(CFG)

    with sess.as_default():
        tf.train.write_graph(
            graph_or_graph_def=sess.graph,
            logdir='',
            name='{:s}/lanenet_model.pb'.format(model_save_dir))
        if weights_path is None:
            log.info('Training from scratch')
            init = tf.global_variables_initializer()
            sess.run(init)
        else:
            log.info('Restore model from last model checkpoint {:s}'.format(
                weights_path))
            saver.restore(sess=sess, save_path=weights_path)

        # Load pretrained weights (best-effort: variables without a matching
        # key in vgg16.npy are silently skipped via the broad except).
        if net_flag == 'vgg' and weights_path is None:
            pretrained_weights = np.load('./data/vgg16.npy',
                                         encoding='latin1').item()

            for vv in tf.trainable_variables():
                weights_key = vv.name.split('/')[-3]
                try:
                    weights = pretrained_weights[weights_key][0]
                    _op = tf.assign(vv, weights)
                    sess.run(_op)
                except Exception as e:
                    continue
        elif net_flag == 'mobile' and weights_path is None:
            pass
            # pretrained_saver.restore(sess=sess, save_path=mobile_pretrained_path)

        train_cost_time_mean = []
        val_cost_time_mean = []
        for step in range(int(initial_step), train_steps):
            # training part
            t_start = time.time()

            with tf.device('/gpu:0'):
                raw_imgs, binary_gt_labels, instance_gt_labels = train_dataset.next_batch(
                    CFG.TRAIN.BATCH_SIZE)
                # gt_imgs = [cv2.resize(tmp,
                #                       dsize=(CFG.TRAIN.IMG_WIDTH, CFG.TRAIN.IMG_HEIGHT),
                #                       dst=tmp,
                #                       interpolation=cv2.INTER_LINEAR)
                #            for tmp in gt_imgs]
                # Mean-subtract raw images (VGG-style preprocessing).
                gt_imgs = [tmp - VGG_MEAN for tmp in raw_imgs]
                # binary_gt_labels = [cv2.resize(tmp,
                #                                dsize=(CFG.TRAIN.IMG_WIDTH, CFG.TRAIN.IMG_HEIGHT),
                #                                dst=tmp,
                #                                interpolation=cv2.INTER_NEAREST)
                #                     for tmp in binary_gt_labels]
                binary_gt_labels = [
                    np.expand_dims(tmp, axis=-1) for tmp in binary_gt_labels
                ]
                # instance_gt_labels = [cv2.resize(tmp,
                #                                  dsize=(CFG.TRAIN.IMG_WIDTH, CFG.TRAIN.IMG_HEIGHT),
                #                                  dst=tmp,
                #                                  interpolation=cv2.INTER_NEAREST)
                #                       for tmp in instance_gt_labels]
                phase_train = 'train'

            _, c, train_accuracy, train_summary, binary_loss, instance_loss, embedding, binary_seg_img, ct = \
                sess.run([optimizer, total_loss, accuracy, train_merge_summary_op,
                          binary_seg_loss, disc_loss, pix_embedding,
                          out_logits_out, counts],
                         feed_dict={input_tensor: gt_imgs,
                                    binary_label_tensor: binary_gt_labels,
                                    instance_label_tensor: instance_gt_labels,
                                    phase: phase_train})
            # binary_label_tensor = tf.assign(tf.multiply(binary_seg_img[0], 255))
            print(ct)
            # Abort the whole run on NaN losses after logging diagnostics.
            if math.isnan(c) or math.isnan(binary_loss) or math.isnan(
                    instance_loss):
                log.error('cost is: {:.5f}'.format(c))
                log.error('binary cost is: {:.5f}'.format(binary_loss))
                log.error('instance cost is: {:.5f}'.format(instance_loss))
                # cv2.imwrite(f'output/{train_start_time}_{net_flag}_nan_image.png', gt_imgs[0] + VGG_MEAN)
                # cv2.imwrite(f'output/{train_start_time}_{net_flag}_nan_instance_label.png', instance_gt_labels[0])
                # cv2.imwrite(f'output/{train_start_time}_{net_flag}_nan_binary_label.png', binary_gt_labels[0] * 255)
                return
            # Dump training debug images every 50 steps.
            if step % 50 == 0:
                if not os.path.exists(img_output_dir):
                    os.mkdir(img_output_dir)
                print("Image Updated...")
                cv2.imwrite(
                    img_output_dir +
                    f'/{train_start_time}_{net_flag}_TRAIN_raw.png',
                    gt_imgs[0] + VGG_MEAN)
                cv2.imwrite(
                    img_output_dir +
                    f'/{train_start_time}_{net_flag}_TRAIN_binary_label.png',
                    binary_gt_labels[0] * 255)
                cv2.imwrite(
                    img_output_dir +
                    f'/{train_start_time}_{net_flag}_TRAIN_instance_label.png',
                    instance_gt_labels[0])
                cv2.imwrite(
                    img_output_dir +
                    f'/{train_start_time}_{net_flag}_TRAIN_bin_seg.png',
                    binary_seg_img[0] * 255)
                # Rescale the first 4 embedding channels to [0,255] for viewing.
                for i in range(4):
                    embedding[0][:, :, i] = minmax_scale(embedding[0][:, :, i])
                embedding_image = np.array(embedding[0], np.uint8)
                cv2.imwrite(
                    img_output_dir +
                    f'/{train_start_time}_{net_flag}_TRAIN_embedding.png',
                    embedding_image)

            cost_time = time.time() - t_start
            train_cost_time_mean.append(cost_time)
            summary_writer.add_summary(summary=train_summary,
                                       global_step=step)

            # validation part
            with tf.device('/gpu:0'):
                gt_imgs_val, binary_gt_labels_val, instance_gt_labels_val \
                    = val_dataset.next_batch(CFG.TRAIN.VAL_BATCH_SIZE)
                # gt_imgs_val = [cv2.resize(tmp,
                #                           dsize=(CFG.TRAIN.IMG_WIDTH, CFG.TRAIN.IMG_HEIGHT),
                #                           dst=tmp,
                #                           interpolation=cv2.INTER_LINEAR)
                #                for tmp in gt_imgs_val]
                gt_imgs_val = [tmp - VGG_MEAN for tmp in gt_imgs_val]
                # binary_gt_labels_val = [cv2.resize(tmp,
                #                                    dsize=(CFG.TRAIN.IMG_WIDTH, CFG.TRAIN.IMG_HEIGHT),
                #                                    dst=tmp)
                #                         for tmp in binary_gt_labels_val]
                binary_gt_labels_val = [
                    np.expand_dims(tmp, axis=-1)
                    for tmp in binary_gt_labels_val
                ]
                # instance_gt_labels_val = [cv2.resize(tmp,
                #                                      dsize=(CFG.TRAIN.IMG_WIDTH, CFG.TRAIN.IMG_HEIGHT),
                #                                      dst=tmp,
                #                                      interpolation=cv2.INTER_NEAREST)
                #                           for tmp in instance_gt_labels_val]
                phase_val = 'test'

            t_start_val = time.time()
            c_val, val_summary, val_accuracy, val_binary_seg_loss, val_instance_seg_loss, embedding, val_binary_seg_img, val_ct = \
                sess.run([total_loss, val_merge_summary_op, accuracy,
                          binary_seg_loss, disc_loss, pix_embedding,
                          out_logits_out, counts],
                         feed_dict={input_tensor: gt_imgs_val,
                                    binary_label_tensor: binary_gt_labels_val,
                                    instance_label_tensor: instance_gt_labels_val,
                                    phase: phase_val})
            # Dump validation debug images every 50 steps.
            if step % 50 == 0:
                if not os.path.exists(img_output_dir):
                    os.mkdir(img_output_dir)
                for i in range(CFG.TRAIN.VAL_BATCH_SIZE):
                    cv2.imwrite(
                        img_output_dir +
                        f'/{train_start_time}_{net_flag}_VAL_{i}_raw.png',
                        gt_imgs_val[i] + VGG_MEAN)
                    cv2.imwrite(
                        img_output_dir +
                        f'/{train_start_time}_{net_flag}_VAL_{i}_bin_seg.png',
                        val_binary_seg_img[i] * 255)
                    for j in range(4):
                        embedding[i][:, :, j] = minmax_scale(embedding[i][:, :, j])
                    embedding_image = np.array(embedding[i], np.uint8)
                    cv2.imwrite(
                        img_output_dir +
                        f'/{train_start_time}_{net_flag}_VAL_{i}_embedding.png',
                        embedding_image)

            summary_writer.add_summary(val_summary, global_step=step)
            cost_time_val = time.time() - t_start_val
            val_cost_time_mean.append(cost_time_val)
            if step % CFG.TRAIN.DISPLAY_STEP == 0:
                log.info(
                    'Step: {:d} total_loss= {:6f} binary_seg_loss= {:6f} instance_seg_loss= {:6f} accuracy= {:6f}'
                    ' mean_cost_time= {:5f}s '.format(
                        step + 1, c, binary_loss, instance_loss,
                        train_accuracy, np.mean(train_cost_time_mean)))
                train_cost_time_mean.clear()
            if step % CFG.TRAIN.TEST_DISPLAY_STEP == 0:
                log.info(
                    'Step_Val: {:d} total_loss= {:6f} binary_seg_loss= {:6f} '
                    'instance_seg_loss= {:6f} accuracy= {:6f} '
                    'mean_cost_time= {:5f}s '.format(
                        step + 1, c_val, val_binary_seg_loss,
                        val_instance_seg_loss, val_accuracy,
                        np.mean(val_cost_time_mean)))
                val_cost_time_mean.clear()
            if step % 2000 == 0:
                saver.save(sess=sess, save_path=model_save_path,
                           global_step=step)
    sess.close()

    return
def train_net(dataset_dir, weights_path=None, net_flag='vgg'):
    """Train LaneNet with gradient clipping and an F1-style accuracy metric.

    Builds the TF1 graph (boolean `phase` placeholder, clipped-gradient
    Momentum optimizer), computes accuracy as the harmonic mean of pixel
    recall and precision (2 / (1/recall + 1/precision)), then runs an
    epoch-based train loop with periodic debug-image dumps, TensorBoard
    summaries, and checkpoints every 1000 epochs.

    NOTE(review): this file defines `train_net` more than once; only the last
    definition is bound at import time — confirm which one is intended.

    :param dataset_dir: directory containing 'train.txt' / 'val.txt'
    :param weights_path: optional checkpoint to restore; when None, trains
        from scratch (and, for 'vgg', loads ./data/vgg16.npy weights)
    :param net_flag: choose which base network to use
    :return: None (side effects: checkpoints, summaries, image dumps)
    """
    # List files enumerating all training / validation samples.
    train_dataset_file = ops.join(dataset_dir, 'train.txt')
    val_dataset_file = ops.join(dataset_dir, 'val.txt')

    assert ops.exists(train_dataset_file)

    train_dataset = lanenet_data_processor.DataSet(train_dataset_file)
    val_dataset = lanenet_data_processor.DataSet(val_dataset_file)

    input_tensor = tf.placeholder(dtype=tf.float32,
                                  shape=[CFG.TRAIN.BATCH_SIZE,
                                         CFG.TRAIN.IMG_HEIGHT,
                                         CFG.TRAIN.IMG_WIDTH, 3],
                                  name='input_tensor')
    binary_label_tensor = tf.placeholder(dtype=tf.int64,
                                         shape=[CFG.TRAIN.BATCH_SIZE,
                                                CFG.TRAIN.IMG_HEIGHT,
                                                CFG.TRAIN.IMG_WIDTH, 1],
                                         name='binary_input_label')
    instance_label_tensor = tf.placeholder(dtype=tf.float32,
                                           shape=[CFG.TRAIN.BATCH_SIZE,
                                                  CFG.TRAIN.IMG_HEIGHT,
                                                  CFG.TRAIN.IMG_WIDTH],
                                           name='instance_input_label')
    # Unlike the sibling train_net variants, phase here is a boolean tensor.
    phase = tf.placeholder(dtype=tf.bool, shape=None, name='net_phase')

    net = lanenet_merge_model.LaneNet(net_flag=net_flag, phase=phase)

    # calculate the loss
    compute_ret = net.compute_loss(input_tensor=input_tensor,
                                   binary_label=binary_label_tensor,
                                   instance_label=instance_label_tensor,
                                   name='lanenet_model')
    total_loss = compute_ret['total_loss']
    binary_seg_loss = compute_ret['binary_seg_loss']
    disc_loss = compute_ret['discriminative_loss']
    pix_embedding = compute_ret['instance_seg_logits']

    # calculate the accuracy
    out_logits = compute_ret['binary_seg_logits']
    out_logits = tf.nn.softmax(logits=out_logits)
    out_logits_out = tf.argmax(out_logits, axis=-1)

    out = tf.argmax(out_logits, axis=-1)
    out = tf.expand_dims(out, axis=-1)

    # Recall = TP / (TP + FN), measured over GT-positive pixels.
    idx = tf.where(tf.equal(binary_label_tensor, 1))
    pix_cls_ret = tf.gather_nd(out, idx)
    recall = tf.count_nonzero(pix_cls_ret)
    recall = tf.divide(recall, tf.cast(tf.shape(pix_cls_ret)[0], tf.int64))

    # "Precision" here = TN / (TN + FP), measured over GT-negative pixels.
    idx = tf.where(tf.equal(binary_label_tensor, 0))
    pix_cls_ret = tf.gather_nd(out, idx)
    precision = tf.subtract(tf.cast(tf.shape(pix_cls_ret)[0], tf.int64),
                            tf.count_nonzero(pix_cls_ret))
    precision = tf.divide(precision,
                          tf.cast(tf.shape(pix_cls_ret)[0], tf.int64))

    # Harmonic mean of the two rates above (F1-style aggregate).
    accuracy = tf.divide(2.0,
                         tf.divide(1.0, recall) + tf.divide(1.0, precision))

    global_step = tf.Variable(0, trainable=False)
    learning_rate = tf.train.exponential_decay(CFG.TRAIN.LEARNING_RATE,
                                               global_step, 100000, 0.1,
                                               staircase=True)
    # Run UPDATE_OPS (e.g. batch-norm moving stats) before each train step;
    # gradients are value-clipped to [-1, 1] before being applied.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate,
                                               momentum=0.9)
        gradients = optimizer.compute_gradients(total_loss)
        capped_gradients = [(tf.clip_by_value(grad, -1., 1.), var)
                            for grad, var in gradients if grad is not None]
        train_op = optimizer.apply_gradients(capped_gradients,
                                             global_step=global_step)

    # Set tf saver
    saver = tf.train.Saver()
    model_save_dir = 'model/tusimple_lanenet'
    if not ops.exists(model_save_dir):
        os.makedirs(model_save_dir)
    train_start_time = time.strftime('%Y-%m-%d-%H-%M-%S',
                                     time.localtime(time.time()))
    model_name = 'tusimple_lanenet_{:s}_{:s}.ckpt'.format(
        net_flag, str(train_start_time))
    model_save_path = ops.join(model_save_dir, model_name)

    # Set tf summary
    tboard_save_path = 'tboard/tusimple_lanenet/{:s}'.format(net_flag)
    if not ops.exists(tboard_save_path):
        os.makedirs(tboard_save_path)
    train_cost_scalar = tf.summary.scalar(name='train_cost', tensor=total_loss)
    val_cost_scalar = tf.summary.scalar(name='val_cost', tensor=total_loss)
    train_accuracy_scalar = tf.summary.scalar(name='train_accuracy',
                                              tensor=accuracy)
    val_accuracy_scalar = tf.summary.scalar(name='val_accuracy',
                                            tensor=accuracy)
    train_binary_seg_loss_scalar = tf.summary.scalar(
        name='train_binary_seg_loss', tensor=binary_seg_loss)
    val_binary_seg_loss_scalar = tf.summary.scalar(name='val_binary_seg_loss',
                                                   tensor=binary_seg_loss)
    train_instance_seg_loss_scalar = tf.summary.scalar(
        name='train_instance_seg_loss', tensor=disc_loss)
    val_instance_seg_loss_scalar = tf.summary.scalar(
        name='val_instance_seg_loss', tensor=disc_loss)
    learning_rate_scalar = tf.summary.scalar(name='learning_rate',
                                             tensor=learning_rate)
    train_merge_summary_op = tf.summary.merge([
        train_accuracy_scalar, train_cost_scalar, learning_rate_scalar,
        train_binary_seg_loss_scalar, train_instance_seg_loss_scalar
    ])
    val_merge_summary_op = tf.summary.merge([
        val_accuracy_scalar, val_cost_scalar, val_binary_seg_loss_scalar,
        val_instance_seg_loss_scalar
    ])

    # Set sess configuration
    sess_config = tf.ConfigProto(allow_soft_placement=True)
    sess_config.gpu_options.per_process_gpu_memory_fraction = CFG.TRAIN.GPU_MEMORY_FRACTION
    sess_config.gpu_options.allow_growth = CFG.TRAIN.TF_ALLOW_GROWTH
    sess_config.gpu_options.allocator_type = 'BFC'

    sess = tf.Session(config=sess_config)

    summary_writer = tf.summary.FileWriter(tboard_save_path)
    summary_writer.add_graph(sess.graph)

    # Set the training parameters
    train_epochs = CFG.TRAIN.EPOCHS

    log.info('Global configuration is as follows:')
    log.info(CFG)

    with sess.as_default():
        tf.train.write_graph(
            graph_or_graph_def=sess.graph,
            logdir='',
            name='{:s}/lanenet_model.pbtxt'.format(model_save_dir))
        if weights_path is None:
            log.info('Training from scratch')
            init = tf.global_variables_initializer()
            sess.run(init)
        else:
            log.info('Restore model from last model checkpoint {:s}'.format(
                weights_path))
            saver.restore(sess=sess, save_path=weights_path)

        # Load pretrained weights (best-effort: variables without a matching
        # key in vgg16.npy are silently skipped via the broad except).
        if net_flag == 'vgg' and weights_path is None:
            pretrained_weights = np.load('./data/vgg16.npy',
                                         encoding='latin1').item()

            for vv in tf.trainable_variables():
                weights_key = vv.name.split('/')[-3]
                try:
                    weights = pretrained_weights[weights_key][0]
                    _op = tf.assign(vv, weights)
                    sess.run(_op)
                except Exception as e:
                    continue

        train_cost_time_mean = []
        for epoch in range(train_epochs):
            # training part
            t_start = time.time()

            gt_imgs, binary_gt_labels, instance_gt_labels = train_dataset.next_batch(
                CFG.TRAIN.BATCH_SIZE)

            # Mean-subtract raw images (VGG-style preprocessing).
            gt_imgs = [tmp - VGG_MEAN for tmp in gt_imgs]

            _, c, train_accuracy, train_summary, binary_loss, instance_loss, embedding, binary_seg_img = \
                sess.run([train_op, total_loss, accuracy,
                          train_merge_summary_op, binary_seg_loss, disc_loss,
                          pix_embedding, out_logits_out],
                         feed_dict={input_tensor: gt_imgs,
                                    binary_label_tensor: binary_gt_labels,
                                    instance_label_tensor: instance_gt_labels,
                                    phase: True})

            # Abort the run on NaN losses after dumping diagnostics.
            # BUGFIX: the original logged `'gradients is: {}'.format(g)` here,
            # but `g` was never defined — the resulting NameError fired before
            # the debug images below could be written, destroying exactly the
            # diagnostics this branch exists to produce. The line is removed.
            if math.isnan(c) or math.isnan(binary_loss) or math.isnan(
                    instance_loss):
                log.error('cost is: {:.5f}'.format(c))
                log.error('binary cost is: {:.5f}'.format(binary_loss))
                log.error('instance cost is: {:.5f}'.format(instance_loss))
                cv2.imwrite('nan_image.png', gt_imgs[0] + VGG_MEAN)
                cv2.imwrite('nan_instance_label.png', instance_gt_labels[0])
                cv2.imwrite('nan_binary_label.png', binary_gt_labels[0] * 255)
                return

            # Dump debug images every 100 epochs.
            if epoch % 100 == 0:
                cv2.imwrite('image.png', gt_imgs[0] + VGG_MEAN)
                cv2.imwrite('binary_label.png', binary_gt_labels[0] * 255)
                cv2.imwrite('instance_label.png', instance_gt_labels[0])
                cv2.imwrite('binary_seg_img.png', binary_seg_img[0] * 255)
                # Rescale the first 4 embedding channels to [0,255] for viewing.
                for i in range(4):
                    embedding[0][:, :, i] = minmax_scale(embedding[0][:, :, i])
                embedding_image = np.array(embedding[0], np.uint8)
                cv2.imwrite('embedding.png', embedding_image)

            cost_time = time.time() - t_start
            train_cost_time_mean.append(cost_time)
            summary_writer.add_summary(summary=train_summary,
                                       global_step=epoch)

            if epoch % CFG.TRAIN.DISPLAY_STEP == 0:
                log.info(
                    'Epoch: {:d} total_loss= {:6f} binary_seg_loss= {:6f} instance_seg_loss= {:6f} accuracy= {:6f}'
                    ' mean_cost_time= {:5f}s '.format(
                        epoch + 1, c, binary_loss, instance_loss,
                        train_accuracy, np.mean(train_cost_time_mean)))
                train_cost_time_mean = []

            if epoch % 1000 == 0:
                saver.save(sess=sess, save_path=model_save_path,
                           global_step=epoch)
    sess.close()

    return
def train_net(dataset_dir, weights_path=None):
    """Train LaneNet (default backbone) with CPU-side data preprocessing.

    Builds the TF1 graph on /gpu:0 (placeholders, loss, recall-style
    accuracy, Momentum optimizer with exponential LR decay), then runs an
    epoch-based loop that resizes/normalizes batches on /cpu:0, validates
    one batch per epoch, writes TensorBoard summaries, and checkpoints
    every 2000 epochs (skipping epoch 0).

    NOTE(review): this file defines `train_net` more than once; only the
    last definition is bound at import time — confirm which one is intended.

    :param dataset_dir: directory containing 'train.txt' / 'val.txt'
    :param weights_path: optional checkpoint to restore; when None, trains
        from scratch
    :return: None (side effects: checkpoints, summaries, image dumps)
    """
    train_dataset_file = ops.join(dataset_dir, 'train.txt')
    val_dataset_file = ops.join(dataset_dir, 'val.txt')

    assert ops.exists(train_dataset_file), '{:s} 不存在'.format(
        train_dataset_file)
    assert ops.exists(val_dataset_file), '{:s} 不存在'.format(val_dataset_file)

    # Create the training and validation dataset instances.
    train_dataset = lanenet_data_processor.DataSet(train_dataset_file)
    val_dataset = lanenet_data_processor.DataSet(val_dataset_file)

    # TensorFlow graph construction.
    with tf.device('/gpu:0'):
        # input_tensor: input images; binary_label_tensor: binary segmentation
        # labels; instance_label_tensor: instance segmentation labels;
        # phase: train/test phase flag (fed as a string at run time).
        input_tensor = tf.placeholder(dtype=tf.float32,
                                      shape=[CFG.TRAIN.BATCH_SIZE,
                                             CFG.TRAIN.IMG_HEIGHT,
                                             CFG.TRAIN.IMG_WIDTH, 3],
                                      name='input_tensor')
        binary_label_tensor = tf.placeholder(dtype=tf.int64,
                                             shape=[CFG.TRAIN.BATCH_SIZE,
                                                    CFG.TRAIN.IMG_HEIGHT,
                                                    CFG.TRAIN.IMG_WIDTH, 1],
                                             name='binary_input_label')
        instance_label_tensor = tf.placeholder(dtype=tf.float32,
                                               shape=[CFG.TRAIN.BATCH_SIZE,
                                                      CFG.TRAIN.IMG_HEIGHT,
                                                      CFG.TRAIN.IMG_WIDTH],
                                               name='instance_input_label')
        phase = tf.placeholder(dtype=tf.string, shape=None, name='net_phase')

        # Build the LaneNet architecture.
        net = lanenet_merge_model.LaneNet(phase=phase)

        # Compute the loss.
        compute_ret = net.compute_loss(input_tensor=input_tensor,
                                       binary_label=binary_label_tensor,
                                       instance_label=instance_label_tensor,
                                       name='lanenet_model')
        total_loss = compute_ret['total_loss']
        binary_seg_loss = compute_ret['binary_seg_loss']
        disc_loss = compute_ret['discriminative_loss']
        pix_embedding = compute_ret['instance_seg_logits']

        # Compute the accuracy metric (TP over GT-positive pixels, i.e. recall).
        out_logits = compute_ret['binary_seg_logits']
        out_logits = tf.nn.softmax(logits=out_logits)
        out_logits_out = tf.argmax(out_logits, axis=-1)

        out = tf.argmax(out_logits, axis=-1)
        out = tf.expand_dims(out, axis=-1)

        idx = tf.where(tf.equal(binary_label_tensor, 1))
        pix_cls_ret = tf.gather_nd(out, idx)
        accuracy = tf.count_nonzero(pix_cls_ret)
        accuracy = tf.divide(accuracy,
                             tf.cast(tf.shape(pix_cls_ret)[0], tf.int64))

        # Set the global step, learning-rate schedule, and optimizer.
        global_step = tf.Variable(0, trainable=False)
        learning_rate = tf.train.exponential_decay(CFG.TRAIN.LEARNING_RATE,
                                                   global_step,
                                                   CFG.TRAIN.LR_DECAY_STEPS,
                                                   CFG.TRAIN.LR_DECAY_RATE,
                                                   staircase=True)
        # Run UPDATE_OPS (e.g. batch-norm moving stats) before each train step.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            optimizer = tf.train.MomentumOptimizer(
                learning_rate=learning_rate,
                momentum=CFG.TRAIN.MOMENTUM).minimize(
                    loss=total_loss,
                    var_list=tf.trainable_variables(),
                    global_step=global_step)

    # Set up the TensorFlow Saver used to persist the model.
    saver = tf.train.Saver()
    # Model save directory.
    model_save_dir = 'model'
    if not ops.exists(model_save_dir):
        os.makedirs(model_save_dir)
    train_start_time = time.strftime('%Y-%m-%d-%H-%M-%S',
                                     time.localtime(time.time()))
    # Model name (suffixed with the training start time).
    model_name = 'lanenet_{:s}.ckpt'.format(str(train_start_time))
    model_save_path = ops.join(model_save_dir, model_name)

    # Set up TensorFlow summaries for TensorBoard.
    # TensorBoard directory (suffixed with the training start time).
    tboard_save_path = 'tboard/lanenet_{:s}'.format(str(train_start_time))
    if not ops.exists(tboard_save_path):
        os.makedirs(tboard_save_path)
    train_cost_scalar = tf.summary.scalar(name='train_cost', tensor=total_loss)
    val_cost_scalar = tf.summary.scalar(name='val_cost', tensor=total_loss)
    train_accuracy_scalar = tf.summary.scalar(name='train_accuracy',
                                              tensor=accuracy)
    val_accuracy_scalar = tf.summary.scalar(name='val_accuracy',
                                            tensor=accuracy)
    train_binary_seg_loss_scalar = tf.summary.scalar(
        name='train_binary_seg_loss', tensor=binary_seg_loss)
    val_binary_seg_loss_scalar = tf.summary.scalar(name='val_binary_seg_loss',
                                                   tensor=binary_seg_loss)
    train_instance_seg_loss_scalar = tf.summary.scalar(
        name='train_instance_seg_loss', tensor=disc_loss)
    val_instance_seg_loss_scalar = tf.summary.scalar(
        name='val_instance_seg_loss', tensor=disc_loss)
    learning_rate_scalar = tf.summary.scalar(name='learning_rate',
                                             tensor=learning_rate)
    train_merge_summary_op = tf.summary.merge([
        train_accuracy_scalar, train_cost_scalar, learning_rate_scalar,
        train_binary_seg_loss_scalar, train_instance_seg_loss_scalar
    ])
    val_merge_summary_op = tf.summary.merge([
        val_accuracy_scalar, val_cost_scalar, val_binary_seg_loss_scalar,
        val_instance_seg_loss_scalar
    ])

    # Global Session configuration.
    sess_config = tf.ConfigProto(allow_soft_placement=True,
                                 log_device_placement=False)
    sess_config.gpu_options.per_process_gpu_memory_fraction = CFG.TRAIN.GPU_MEMORY_FRACTION
    sess_config.gpu_options.allow_growth = CFG.TRAIN.TF_ALLOW_GROWTH
    sess_config.gpu_options.allocator_type = 'BFC'

    sess = tf.Session(config=sess_config)

    summary_writer = tf.summary.FileWriter(tboard_save_path)
    summary_writer.add_graph(sess.graph)

    # Global training parameters; print them for the record.
    train_epochs = CFG.TRAIN.EPOCHS
    log.info('Global configuration is as follows:')
    log.info(CFG)

    # TensorFlow Session execution.
    with sess.as_default():
        # Save the graph definition to lanenet_model.pb.
        tf.train.write_graph(
            graph_or_graph_def=sess.graph,
            logdir='',
            name='{:s}/lanenet_model.pb'.format(model_save_dir))
        # If no pretrained model exists, initialize from scratch; otherwise
        # restore the checkpoint for transfer learning.
        if weights_path is None:
            log.info('Training from scratch')
            init = tf.global_variables_initializer()
            sess.run(init)
        else:
            log.info('Restore model from last model checkpoint {:s}'.format(
                weights_path))
            saver.restore(sess=sess, save_path=weights_path)

        train_cost_time_mean = []
        val_cost_time_mean = []
        for epoch in range(train_epochs):
            # Training part.
            t_start = time.time()

            with tf.device('/cpu:0'):
                # gt_imgs: raw images; binary_gt_labels: binary segmentation
                # labels; instance_gt_labels: instance segmentation labels.
                gt_imgs, binary_gt_labels, instance_gt_labels = train_dataset.next_batch(
                    CFG.TRAIN.BATCH_SIZE)
                gt_imgs = [
                    cv2.resize(tmp,
                               dsize=(CFG.TRAIN.IMG_WIDTH,
                                      CFG.TRAIN.IMG_HEIGHT),
                               dst=tmp,
                               interpolation=cv2.INTER_LINEAR)
                    for tmp in gt_imgs
                ]
                # Mean-subtract images (VGG-style preprocessing).
                gt_imgs = [tmp - VGG_MEAN for tmp in gt_imgs]
                # Nearest-neighbor resize keeps label values discrete.
                binary_gt_labels = [
                    cv2.resize(tmp,
                               dsize=(CFG.TRAIN.IMG_WIDTH,
                                      CFG.TRAIN.IMG_HEIGHT),
                               dst=tmp,
                               interpolation=cv2.INTER_NEAREST)
                    for tmp in binary_gt_labels
                ]
                binary_gt_labels = [
                    np.expand_dims(tmp, axis=-1) for tmp in binary_gt_labels
                ]
                instance_gt_labels = [
                    cv2.resize(tmp,
                               dsize=(CFG.TRAIN.IMG_WIDTH,
                                      CFG.TRAIN.IMG_HEIGHT),
                               dst=tmp,
                               interpolation=cv2.INTER_NEAREST)
                    for tmp in instance_gt_labels
                ]
                phase_train = 'train'

            # Run one LaneNet training step.
            _, c, train_accuracy, train_summary, binary_loss, instance_loss, embedding, binary_seg_img = \
                sess.run([optimizer, total_loss, accuracy,
                          train_merge_summary_op, binary_seg_loss, disc_loss,
                          pix_embedding, out_logits_out],
                         feed_dict={input_tensor: gt_imgs,
                                    binary_label_tensor: binary_gt_labels,
                                    instance_label_tensor: instance_gt_labels,
                                    phase: phase_train})

            # Exception handling: when a loss becomes NaN, log the anomaly,
            # save the offending inputs, and stop training.
            if math.isnan(c) or math.isnan(binary_loss) or math.isnan(
                    instance_loss):
                log.error('Epoch: {:d} Total cost: {:}'.format(epoch + 1, c))
                log.error('Epoch: {:d} Total binary cost: {:}'.format(
                    epoch + 1, binary_loss))
                log.error('Epoch: {:d} Total instance cost: {:}'.format(
                    epoch + 1, instance_loss))
                cv2.imwrite('nan_image.png', gt_imgs[0] + VGG_MEAN)
                cv2.imwrite('nan_instance_label.png', instance_gt_labels[0])
                cv2.imwrite('nan_binary_label.png', binary_gt_labels[0] * 255)
                return

            cost_time = time.time() - t_start
            train_cost_time_mean.append(cost_time)

            # Record the training summary in TensorBoard.
            summary_writer.add_summary(summary=train_summary,
                                       global_step=epoch)

            # Validation part.
            with tf.device('/cpu:0'):
                gt_imgs_val, binary_gt_labels_val, instance_gt_labels_val \
                    = val_dataset.next_batch(CFG.TRAIN.VAL_BATCH_SIZE)
                gt_imgs_val = [
                    cv2.resize(tmp,
                               dsize=(CFG.TRAIN.IMG_WIDTH,
                                      CFG.TRAIN.IMG_HEIGHT),
                               dst=tmp,
                               interpolation=cv2.INTER_LINEAR)
                    for tmp in gt_imgs_val
                ]
                gt_imgs_val = [tmp - VGG_MEAN for tmp in gt_imgs_val]
                binary_gt_labels_val = [
                    cv2.resize(tmp,
                               dsize=(CFG.TRAIN.IMG_WIDTH,
                                      CFG.TRAIN.IMG_HEIGHT),
                               dst=tmp) for tmp in binary_gt_labels_val
                ]
                binary_gt_labels_val = [
                    np.expand_dims(tmp, axis=-1)
                    for tmp in binary_gt_labels_val
                ]
                instance_gt_labels_val = [
                    cv2.resize(tmp,
                               dsize=(CFG.TRAIN.IMG_WIDTH,
                                      CFG.TRAIN.IMG_HEIGHT),
                               dst=tmp,
                               interpolation=cv2.INTER_NEAREST)
                    for tmp in instance_gt_labels_val
                ]
                phase_val = 'test'

            t_start_val = time.time()
            # Run LaneNet validation on one batch.
            c_val, val_summary, val_accuracy, val_binary_seg_loss, val_instance_seg_loss = \
                sess.run([total_loss, val_merge_summary_op, accuracy,
                          binary_seg_loss, disc_loss],
                         feed_dict={input_tensor: gt_imgs_val,
                                    binary_label_tensor: binary_gt_labels_val,
                                    instance_label_tensor: instance_gt_labels_val,
                                    phase: phase_val})

            # Record the validation summary in TensorBoard.
            summary_writer.add_summary(val_summary, global_step=epoch)

            cost_time_val = time.time() - t_start_val
            val_cost_time_mean.append(cost_time_val)

            # Print the training log.
            if epoch % CFG.TRAIN.DISPLAY_STEP == 0:
                log.info(
                    'Epoch: {:d} total_loss= {:6f} binary_seg_loss= {:6f} instance_seg_loss= {:6f} accuracy= {:6f}'
                    ' mean_cost_time= {:5f}s '.format(
                        epoch + 1, c, binary_loss, instance_loss,
                        train_accuracy, np.mean(train_cost_time_mean)))
                train_cost_time_mean.clear()

            # Print the validation log.
            if epoch % CFG.TRAIN.TEST_DISPLAY_STEP == 0:
                log.info(
                    'Epoch_Val: {:d} total_loss= {:6f} binary_seg_loss= {:6f} '
                    'instance_seg_loss= {:6f} accuracy= {:6f} '
                    'mean_cost_time= {:5f}s '.format(
                        epoch + 1, c_val, val_binary_seg_loss,
                        val_instance_seg_loss, val_accuracy,
                        np.mean(val_cost_time_mean)))
                val_cost_time_mean.clear()

            # Save the model checkpoint.
            if epoch % 2000 == 0 and epoch != 0:
                saver.save(sess=sess, save_path=model_save_path,
                           global_step=epoch)

    # Close the Session.
    sess.close()

    return