Example #1
def train_lanenet(dataset_dir, weights_path=None, net_flag='vgg'):
    """

    :param dataset_dir:
    :param net_flag: choose which base network to use
    :param weights_path:
    :return:
    """
    train_dataset = lanenet_data_feed_pipline.LaneNetDataFeeder(
        dataset_dir=dataset_dir, flags='train')
    val_dataset = lanenet_data_feed_pipline.LaneNetDataFeeder(
        dataset_dir=dataset_dir, flags='val')

    with tf.device('/gpu:1'):
        # set lanenet
        train_net = lanenet.LaneNet(net_flag=net_flag,
                                    phase='train',
                                    reuse=False)
        val_net = lanenet.LaneNet(net_flag=net_flag, phase='val', reuse=True)

        # set compute graph node for training
        train_images, train_binary_labels, train_instance_labels = train_dataset.inputs(
            CFG.TRAIN.BATCH_SIZE, 1)

        train_compute_ret = train_net.compute_loss(
            input_tensor=train_images,
            binary_label=train_binary_labels,
            instance_label=train_instance_labels,
            name='lanenet_model')
        train_total_loss = train_compute_ret['total_loss']
        train_binary_seg_loss = train_compute_ret['binary_seg_loss']
        train_disc_loss = train_compute_ret['discriminative_loss']
        train_pix_embedding = train_compute_ret['instance_seg_logits']

        train_prediction_logits = train_compute_ret['binary_seg_logits']
        train_prediction_score = tf.nn.softmax(logits=train_prediction_logits)
        train_prediction = tf.argmax(train_prediction_score, axis=-1)

        train_accuracy = evaluate_model_utils.calculate_model_precision(
            train_compute_ret['binary_seg_logits'], train_binary_labels)
        train_fp = evaluate_model_utils.calculate_model_fp(
            train_compute_ret['binary_seg_logits'], train_binary_labels)
        train_fn = evaluate_model_utils.calculate_model_fn(
            train_compute_ret['binary_seg_logits'], train_binary_labels)
        train_binary_seg_ret_for_summary = evaluate_model_utils.get_image_summary(
            img=train_prediction)
        train_embedding_ret_for_summary = evaluate_model_utils.get_image_summary(
            img=train_pix_embedding)

        train_cost_scalar = tf.summary.scalar(name='train_cost',
                                              tensor=train_total_loss)
        train_accuracy_scalar = tf.summary.scalar(name='train_accuracy',
                                                  tensor=train_accuracy)
        train_binary_seg_loss_scalar = tf.summary.scalar(
            name='train_binary_seg_loss', tensor=train_binary_seg_loss)
        train_instance_seg_loss_scalar = tf.summary.scalar(
            name='train_instance_seg_loss', tensor=train_disc_loss)
        train_fn_scalar = tf.summary.scalar(name='train_fn', tensor=train_fn)
        train_fp_scalar = tf.summary.scalar(name='train_fp', tensor=train_fp)
        train_binary_seg_ret_img = tf.summary.image(
            name='train_binary_seg_ret',
            tensor=train_binary_seg_ret_for_summary)
        train_embedding_feats_ret_img = tf.summary.image(
            name='train_embedding_feats_ret',
            tensor=train_embedding_ret_for_summary)
        train_merge_summary_op = tf.summary.merge([
            train_accuracy_scalar, train_cost_scalar,
            train_binary_seg_loss_scalar, train_instance_seg_loss_scalar,
            train_fn_scalar, train_fp_scalar, train_binary_seg_ret_img,
            train_embedding_feats_ret_img
        ])

        # set compute graph node for validation
        val_images, val_binary_labels, val_instance_labels = val_dataset.inputs(
            CFG.TRAIN.VAL_BATCH_SIZE, 1)

        val_compute_ret = val_net.compute_loss(
            input_tensor=val_images,
            binary_label=val_binary_labels,
            instance_label=val_instance_labels,
            name='lanenet_model')
        val_total_loss = val_compute_ret['total_loss']
        val_binary_seg_loss = val_compute_ret['binary_seg_loss']
        val_disc_loss = val_compute_ret['discriminative_loss']
        val_pix_embedding = val_compute_ret['instance_seg_logits']

        val_prediction_logits = val_compute_ret['binary_seg_logits']
        val_prediction_score = tf.nn.softmax(logits=val_prediction_logits)
        val_prediction = tf.argmax(val_prediction_score, axis=-1)

        val_accuracy = evaluate_model_utils.calculate_model_precision(
            val_compute_ret['binary_seg_logits'], val_binary_labels)
        val_fp = evaluate_model_utils.calculate_model_fp(
            val_compute_ret['binary_seg_logits'], val_binary_labels)
        val_fn = evaluate_model_utils.calculate_model_fn(
            val_compute_ret['binary_seg_logits'], val_binary_labels)
        val_binary_seg_ret_for_summary = evaluate_model_utils.get_image_summary(
            img=val_prediction)
        val_embedding_ret_for_summary = evaluate_model_utils.get_image_summary(
            img=val_pix_embedding)

        val_cost_scalar = tf.summary.scalar(name='val_cost',
                                            tensor=val_total_loss)
        val_accuracy_scalar = tf.summary.scalar(name='val_accuracy',
                                                tensor=val_accuracy)
        val_binary_seg_loss_scalar = tf.summary.scalar(
            name='val_binary_seg_loss', tensor=val_binary_seg_loss)
        val_instance_seg_loss_scalar = tf.summary.scalar(
            name='val_instance_seg_loss', tensor=val_disc_loss)
        val_fn_scalar = tf.summary.scalar(name='val_fn', tensor=val_fn)
        val_fp_scalar = tf.summary.scalar(name='val_fp', tensor=val_fp)
        val_binary_seg_ret_img = tf.summary.image(
            name='val_binary_seg_ret', tensor=val_binary_seg_ret_for_summary)
        val_embedding_feats_ret_img = tf.summary.image(
            name='val_embedding_feats_ret',
            tensor=val_embedding_ret_for_summary)
        val_merge_summary_op = tf.summary.merge([
            val_accuracy_scalar, val_cost_scalar, val_binary_seg_loss_scalar,
            val_instance_seg_loss_scalar, val_fn_scalar, val_fp_scalar,
            val_binary_seg_ret_img, val_embedding_feats_ret_img
        ])

        # set optimizer
        global_step = tf.Variable(0, trainable=False)
        learning_rate = tf.train.polynomial_decay(
            learning_rate=CFG.TRAIN.LEARNING_RATE,
            global_step=global_step,
            decay_steps=CFG.TRAIN.EPOCHS,
            power=0.9)

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)  # for batch normalization
        with tf.control_dependencies(update_ops):
            optimizer = tf.train.MomentumOptimizer(
                learning_rate=learning_rate,
                momentum=CFG.TRAIN.MOMENTUM).minimize(
                    loss=train_total_loss,
                    var_list=tf.trainable_variables(),
                    global_step=global_step)

    # Set tf model save path
    model_save_dir = 'model/tusimple_lanenet_{:s}'.format(net_flag)
    if not ops.exists(model_save_dir):
        os.makedirs(model_save_dir)
    train_start_time = time.strftime('%Y-%m-%d-%H-%M-%S',
                                     time.localtime(time.time()))
    model_name = 'tusimple_lanenet_{:s}_{:s}.ckpt'.format(
        net_flag, str(train_start_time))
    model_save_path = ops.join(model_save_dir, model_name)
    saver = tf.train.Saver()

    # Set tf summary save path
    tboard_save_path = 'tboard/tusimple_lanenet_{:s}'.format(net_flag)
    if not ops.exists(tboard_save_path):
        os.makedirs(tboard_save_path)

    # Set sess configuration
    sess_config = tf.ConfigProto(allow_soft_placement=True)
    sess_config.gpu_options.per_process_gpu_memory_fraction = CFG.TRAIN.GPU_MEMORY_FRACTION
    sess_config.gpu_options.allow_growth = CFG.TRAIN.TF_ALLOW_GROWTH
    sess_config.gpu_options.allocator_type = 'BFC'

    sess = tf.Session(config=sess_config)

    summary_writer = tf.summary.FileWriter(tboard_save_path)
    summary_writer.add_graph(sess.graph)

    # Set the training parameters
    train_epochs = CFG.TRAIN.EPOCHS

    log.info('Global configuration is as follows:')
    log.info(CFG)

    with sess.as_default():

        if weights_path is None:
            log.info('Training from scratch')
            init = tf.global_variables_initializer()
            sess.run(init)
        else:
            log.info('Restore model from last model checkpoint {:s}'.format(
                weights_path))
            saver.restore(sess=sess, save_path=weights_path)

        if net_flag == 'vgg' and weights_path is None:
            load_pretrained_weights(tf.trainable_variables(),
                                    './data/vgg16.npy', sess)

        train_cost_time_mean = []
        for epoch in range(train_epochs):
            # training part
            t_start = time.time()

            _, train_c, train_accuracy_figure, train_fn_figure, train_fp_figure, lr, train_summary, train_binary_loss, \
            train_instance_loss, train_embeddings, train_binary_seg_imgs, train_gt_imgs, \
            train_binary_gt_labels, train_instance_gt_labels = \
                sess.run([optimizer, train_total_loss, train_accuracy, train_fn, train_fp,
                          learning_rate, train_merge_summary_op, train_binary_seg_loss,
                          train_disc_loss, train_pix_embedding, train_prediction,
                          train_images, train_binary_labels, train_instance_labels])

            if math.isnan(train_c) or math.isnan(
                    train_binary_loss) or math.isnan(train_instance_loss):
                log.error('cost is: {:.5f}'.format(train_c))
                log.error('binary cost is: {:.5f}'.format(train_binary_loss))
                log.error(
                    'instance cost is: {:.5f}'.format(train_instance_loss))
                return

            if epoch % 100 == 0:
                record_training_intermediate_result(
                    gt_images=train_gt_imgs,
                    gt_binary_labels=train_binary_gt_labels,
                    gt_instance_labels=train_instance_gt_labels,
                    binary_seg_images=train_binary_seg_imgs,
                    pix_embeddings=train_embeddings)
            summary_writer.add_summary(summary=train_summary,
                                       global_step=epoch)

            if epoch % CFG.TRAIN.DISPLAY_STEP == 0:
                log.info(
                    'Epoch: {:d} total_loss= {:6f} binary_seg_loss= {:6f} '
                    'instance_seg_loss= {:6f} accuracy= {:6f} fp= {:6f} fn= {:6f}'
                    ' lr= {:6f} mean_cost_time= {:5f}s '.format(
                        epoch + 1, train_c, train_binary_loss,
                        train_instance_loss, train_accuracy_figure,
                        train_fp_figure, train_fn_figure, lr,
                        np.mean(train_cost_time_mean)))
                del train_cost_time_mean[:]

            # validation part
            val_c, val_accuracy_figure, val_fn_figure, val_fp_figure, val_summary, val_binary_loss, \
            val_instance_loss, val_embeddings, val_binary_seg_imgs, val_gt_imgs, \
            val_binary_gt_labels, val_instance_gt_labels = \
                sess.run([val_total_loss, val_accuracy, val_fn, val_fp,
                          val_merge_summary_op, val_binary_seg_loss,
                          val_disc_loss, val_pix_embedding, val_prediction,
                          val_images, val_binary_labels, val_instance_labels])

            if math.isnan(val_c) or math.isnan(val_binary_loss) or math.isnan(
                    val_instance_loss):
                log.error('cost is: {:.5f}'.format(val_c))
                log.error('binary cost is: {:.5f}'.format(val_binary_loss))
                log.error('instance cost is: {:.5f}'.format(val_instance_loss))
                return

            if epoch % 100 == 0:
                record_training_intermediate_result(
                    gt_images=val_gt_imgs,
                    gt_binary_labels=val_binary_gt_labels,
                    gt_instance_labels=val_instance_gt_labels,
                    binary_seg_images=val_binary_seg_imgs,
                    pix_embeddings=val_embeddings,
                    flag='val')

            cost_time = time.time() - t_start
            train_cost_time_mean.append(cost_time)
            summary_writer.add_summary(summary=val_summary, global_step=epoch)

            if epoch % CFG.TRAIN.VAL_DISPLAY_STEP == 0:
                log.info(
                    'Epoch_Val: {:d} total_loss= {:6f} binary_seg_loss= {:6f} '
                    'instance_seg_loss= {:6f} accuracy= {:6f} fp= {:6f} fn= {:6f}'
                    ' mean_cost_time= {:5f}s '.format(
                        epoch + 1, val_c, val_binary_loss, val_instance_loss,
                        val_accuracy_figure, val_fp_figure, val_fn_figure,
                        np.mean(train_cost_time_mean)))
                del train_cost_time_mean[:]

            if epoch % 2000 == 0:
                saver.save(sess=sess,
                           save_path=model_save_path,
                           global_step=global_step)

    return
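
A minimal command-line entry point for the function above might look like the following sketch; the flag names, defaults, and the init_args() helper are illustrative assumptions rather than part of the original script.

import argparse


def init_args():
    """Parse hypothetical command-line flags for train_lanenet above."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset_dir', type=str, required=True,
                        help='root directory of the lane dataset')
    parser.add_argument('--weights_path', type=str, default=None,
                        help='optional checkpoint to restore before training')
    parser.add_argument('--net_flag', type=str, default='vgg',
                        help='which base network to use')
    return parser.parse_args()


if __name__ == '__main__':
    args = init_args()
    train_lanenet(args.dataset_dir,
                  weights_path=args.weights_path,
                  net_flag=args.net_flag)
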
Example #2
def train_lanenet(weights_path=None,
                  net_flag='vgg',
                  version_flag='',
                  scratch=False):
    """
    :param weights_path:
    :param net_flag: choose which base network to use
    :param version_flag: exp flag
    :return:
    """
    # ========================== placeholder ========================= #
    with tf.name_scope('train_input'):
        train_input_tensor = tf.placeholder(dtype=tf.float32,
                                            name='input_image',
                                            shape=[None, None, None, 3])
        train_binary_label_tensor = tf.placeholder(dtype=tf.float32,
                                                   name='binary_input_label',
                                                   shape=[None, None, None, 1])
        train_instance_label_tensor = tf.placeholder(
            dtype=tf.float32,
            name='instance_input_label',
            shape=[None, None, None, 1])

    with tf.name_scope('val_input'):
        val_input_tensor = tf.placeholder(dtype=tf.float32,
                                          name='input_image',
                                          shape=[None, None, None, 3])
        val_binary_label_tensor = tf.placeholder(dtype=tf.float32,
                                                 name='binary_input_label',
                                                 shape=[None, None, None, 1])
        val_instance_label_tensor = tf.placeholder(dtype=tf.float32,
                                                   name='instance_input_label',
                                                   shape=[None, None, None, 1])

    # ================================================================ #
    #                           Define Network                         #
    # ================================================================ #
    train_net = lanenet.LaneNet(net_flag=net_flag,
                                phase='train',
                                reuse=tf.AUTO_REUSE)
    val_net = lanenet.LaneNet(net_flag=net_flag, phase='val', reuse=True)
    # ---------------------------------------------------------------- #

    # ================================================================ #
    #                       Train Input & Output                       #
    # ================================================================ #
    trainset = DataSet('train')
    # trainset = MergeDataSet('train_lane')
    train_compute_ret = train_net.compute_loss(
        input_tensor=train_input_tensor,
        binary_label=train_binary_label_tensor,
        instance_label=train_instance_label_tensor,
        name='lanenet_model')
    train_total_loss = train_compute_ret['total_loss']
    train_binary_seg_loss = train_compute_ret['binary_seg_loss']  # semantic segmentation loss
    train_disc_loss = train_compute_ret[
        'discriminative_loss']  # embedding loss
    train_pix_embedding = train_compute_ret[
        'instance_seg_logits']  # embedding feature, HxWxN
    train_l2_reg_loss = train_compute_ret['l2_reg_loss']

    train_prediction_logits = train_compute_ret[
        'binary_seg_logits']  # semantic segmentation logits, HxWx2
    train_prediction_score = tf.nn.softmax(logits=train_prediction_logits)
    train_prediction = tf.argmax(train_prediction_score, axis=-1)  # binary segmentation map

    train_accuracy = evaluate_model_utils.calculate_model_precision(
        train_compute_ret['binary_seg_logits'], train_binary_label_tensor)
    train_fp = evaluate_model_utils.calculate_model_fp(
        train_compute_ret['binary_seg_logits'], train_binary_label_tensor)
    train_fn = evaluate_model_utils.calculate_model_fn(
        train_compute_ret['binary_seg_logits'], train_binary_label_tensor)
    train_binary_seg_ret_for_summary = evaluate_model_utils.get_image_summary(
        img=train_prediction)  # (I - min) * 255 / (max - min), normalized to 0-255
    train_embedding_ret_for_summary = evaluate_model_utils.get_image_summary(
        img=train_pix_embedding)  # (I - min) * 255 / (max - min), normalized to 0-255
    # ---------------------------------------------------------------- #

    # ================================================================ #
    #                          Define Optimizer                        #
    # ================================================================ #
    # set optimizer
    global_step = tf.Variable(0, trainable=False, name='global_step')
    # learning_rate = tf.train.cosine_decay_restarts( # cosine decay with warm restarts
    #     learning_rate=cfg.TRAIN.LEARNING_RATE,      # initial learning rate
    #     global_step=global_step,                    # current iteration count
    #     first_decay_steps=cfg.TRAIN.STEPS/3,        # length of the first decay period
    #     t_mul=2.0,                                  # period multiplier after each restart
    #     m_mul=1.0,                                  # initial-learning-rate multiplier after each restart
    #     alpha=0.1,                                  # minimum learning rate = alpha * learning_rate
    # )
    learning_rate = tf.train.polynomial_decay(  # polynomial decay
        learning_rate=cfg.TRAIN.LEARNING_RATE,  # initial learning rate
        global_step=global_step,  # current iteration count
        decay_steps=cfg.TRAIN.STEPS / 4,  # steps per decay cycle; the rate decays towards end_learning_rate over this span
        end_learning_rate=cfg.TRAIN.LEARNING_RATE / 10,  # minimum learning rate
        power=0.9,
        cycle=True)
    learning_rate_scalar = tf.summary.scalar(name='learning_rate',
                                             tensor=learning_rate)
    update_ops = tf.get_collection(
        tf.GraphKeys.UPDATE_OPS)  # for batch normalization
    with tf.control_dependencies(update_ops):
        optimizer = tf.train.MomentumOptimizer(
            learning_rate=learning_rate, momentum=cfg.TRAIN.MOMENTUM).minimize(
                loss=train_total_loss,
                var_list=tf.trainable_variables(),
                global_step=global_step)
    # ---------------------------------------------------------------- #

    # ================================================================ #
    #                           Train Summary                          #
    # ================================================================ #
    train_loss_scalar = tf.summary.scalar(name='train_cost',
                                          tensor=train_total_loss)
    train_accuracy_scalar = tf.summary.scalar(name='train_accuracy',
                                              tensor=train_accuracy)
    train_binary_seg_loss_scalar = tf.summary.scalar(
        name='train_binary_seg_loss', tensor=train_binary_seg_loss)
    train_instance_seg_loss_scalar = tf.summary.scalar(
        name='train_instance_seg_loss', tensor=train_disc_loss)
    train_fn_scalar = tf.summary.scalar(name='train_fn', tensor=train_fn)
    train_fp_scalar = tf.summary.scalar(name='train_fp', tensor=train_fp)
    train_binary_seg_ret_img = tf.summary.image(
        name='train_binary_seg_ret', tensor=train_binary_seg_ret_for_summary)
    train_embedding_feats_ret_img = tf.summary.image(
        name='train_embedding_feats_ret',
        tensor=train_embedding_ret_for_summary)
    train_merge_summary_op = tf.summary.merge([
        train_accuracy_scalar, train_loss_scalar, train_binary_seg_loss_scalar,
        train_instance_seg_loss_scalar, train_fn_scalar, train_fp_scalar,
        train_binary_seg_ret_img, train_embedding_feats_ret_img,
        learning_rate_scalar
    ])
    # ---------------------------------------------------------------- #

    # ================================================================ #
    #                        Val Input & Output                        #
    # ================================================================ #
    valset = DataSet('val', net_flag)
    # valset = MergeDataSet('test_lane')
    val_compute_ret = val_net.compute_loss(
        input_tensor=val_input_tensor,
        binary_label=val_binary_label_tensor,
        instance_label=val_instance_label_tensor,
        name='lanenet_model')
    val_total_loss = val_compute_ret['total_loss']
    val_binary_seg_loss = val_compute_ret['binary_seg_loss']
    val_disc_loss = val_compute_ret['discriminative_loss']
    val_pix_embedding = val_compute_ret['instance_seg_logits']

    val_prediction_logits = val_compute_ret['binary_seg_logits']
    val_prediction_score = tf.nn.softmax(logits=val_prediction_logits)
    val_prediction = tf.argmax(val_prediction_score, axis=-1)

    val_accuracy = evaluate_model_utils.calculate_model_precision(
        val_compute_ret['binary_seg_logits'], val_binary_label_tensor)
    val_fp = evaluate_model_utils.calculate_model_fp(
        val_compute_ret['binary_seg_logits'], val_binary_label_tensor)
    val_fn = evaluate_model_utils.calculate_model_fn(
        val_compute_ret['binary_seg_logits'], val_binary_label_tensor)
    val_binary_seg_ret_for_summary = evaluate_model_utils.get_image_summary(
        img=val_prediction)
    val_embedding_ret_for_summary = evaluate_model_utils.get_image_summary(
        img=val_pix_embedding)
    # ---------------------------------------------------------------- #

    # ================================================================ #
    #                            VAL Summary                           #
    # ================================================================ #
    val_loss_scalar = tf.summary.scalar(name='val_cost', tensor=val_total_loss)
    val_accuracy_scalar = tf.summary.scalar(name='val_accuracy',
                                            tensor=val_accuracy)
    val_binary_seg_loss_scalar = tf.summary.scalar(name='val_binary_seg_loss',
                                                   tensor=val_binary_seg_loss)
    val_instance_seg_loss_scalar = tf.summary.scalar(
        name='val_instance_seg_loss', tensor=val_disc_loss)
    val_fn_scalar = tf.summary.scalar(name='val_fn', tensor=val_fn)
    val_fp_scalar = tf.summary.scalar(name='val_fp', tensor=val_fp)
    val_binary_seg_ret_img = tf.summary.image(
        name='val_binary_seg_ret', tensor=val_binary_seg_ret_for_summary)
    val_embedding_feats_ret_img = tf.summary.image(
        name='val_embedding_feats_ret', tensor=val_embedding_ret_for_summary)
    val_merge_summary_op = tf.summary.merge([
        val_accuracy_scalar, val_loss_scalar, val_binary_seg_loss_scalar,
        val_instance_seg_loss_scalar, val_fn_scalar, val_fp_scalar,
        val_binary_seg_ret_img, val_embedding_feats_ret_img
    ])
    # ---------------------------------------------------------------- #

    # ================================================================ #
    #                      Config Saver & Session                      #
    # ================================================================ #
    # Set tf model save path
    model_save_dir = 'model/tusimple_lanenet_{:s}_{:s}'.format(
        net_flag, version_flag)
    os.makedirs(model_save_dir, exist_ok=True)
    train_start_time = time.strftime('%Y-%m-%d-%H-%M-%S',
                                     time.localtime(time.time()))
    model_name = 'tusimple_lanenet_{:s}_{:s}.ckpt'.format(
        net_flag, str(train_start_time))
    model_save_path = ops.join(model_save_dir, model_name)

    # ==============================
    if scratch:
        """
        删除 Momentum 的参数, 注意这里保存的 meta 文件也会删了
        tensorflow 在 save model 的时候,如果选择了 global_step 选项,会 global_step 值也保存下来,
        然后 restore 的时候也就会接着这个 global_step 继续训练下去,因此需要去掉
        """
        variables = tf.contrib.framework.get_variables_to_restore()
        variables_to_resotre = [
            v for v in variables if 'Momentum' not in v.name.split('/')[-1]
        ]
        variables_to_resotre = [
            v for v in variables_to_resotre
            if 'global_step' not in v.name.split('/')[-1]
        ]  # remove global step
        restore_saver = tf.train.Saver(variables_to_resotre)
    else:
        restore_saver = tf.train.Saver()
    saver = tf.train.Saver(max_to_keep=10)
    # ==============================

    # Set tf summary save path
    tboard_save_path = 'tboard/tusimple_lanenet_{:s}_{:s}'.format(
        net_flag, version_flag)
    os.makedirs(tboard_save_path, exist_ok=True)

    # Set sess configuration
    # ============================== config GPU
    sess_config = tf.ConfigProto(allow_soft_placement=True)
    # sess_config.gpu_options.per_process_gpu_memory_fraction = cfg.TRAIN.GPU_MEMORY_FRACTION
    sess_config.gpu_options.allow_growth = cfg.TRAIN.TF_ALLOW_GROWTH
    sess_config.gpu_options.allocator_type = 'BFC'
    # ==============================
    sess = tf.Session(config=sess_config)

    summary_writer = tf.summary.FileWriter(tboard_save_path)
    summary_writer.add_graph(sess.graph)
    # ---------------------------------------------------------------- #

    # Set the training parameters
    import math
    one_epoch2step = math.ceil(cfg.TRAIN.TRAIN_SIZE /
                               cfg.TRAIN.BATCH_SIZE)  # number of batches per training epoch
    total_epoch = math.ceil(cfg.TRAIN.STEPS / one_epoch2step)  # total number of epochs to train

    log.info('Global configuration is as follows:')
    log.info(cfg)
    max_acc = 0.9
    save_num = 0
    val_step = 0
    # ================================================================ #
    #                            Train & Val                           #
    # ================================================================ #
    with sess.as_default():
        # ============================== load pretrain model
        # if weights_path is None:
        #     log.info('Training from scratch')
        #     sess.run(tf.global_variables_initializer())
        # elif net_flag == 'vgg' and weights_path is None:
        #     load_pretrained_weights(tf.trainable_variables(), './data/vgg16.npy', sess)
        # elif scratch: # start training anew from the given weights, similar to Caffe's --weights
        #     sess.run(tf.global_variables_initializer())
        #     log.info('Restore model from last model checkpoint {:s}, scratch'.format(weights_path))
        #     try:
        #         restore_saver.restore(sess=sess, save_path=weights_path)
        #     except:
        #         log.info('model maybe is not exist!')
        # else: # resume training, similar to Caffe's --snapshot
        #     log.info('Restore model from last model checkpoint {:s}'.format(weights_path))
        #     try:
        #         restore_saver.restore(sess=sess, save_path=weights_path)
        #     except:
        #         log.info('model maybe is not exist!')
        sess.run(tf.global_variables_initializer())
        # ==============================
        for epoch in range(total_epoch):
            # ================================================================ #
            #                               Train                              #
            # ================================================================ #
            train_epoch_loss = []
            pbar_train = tqdm(trainset)
            train_t_start = time.time()
            for gt_imgs, binary_gt_labels, instance_gt_labels in pbar_train:
                _, global_step_val, train_loss, train_accuracy_figure, train_fn_figure, train_fp_figure, \
                lr, train_summary, train_binary_loss, train_instance_loss, \
                train_embeddings, train_binary_seg_imgs, train_l2_loss = \
                    sess.run([optimizer, global_step, train_total_loss, train_accuracy, train_fn, train_fp,
                              learning_rate, train_merge_summary_op, train_binary_seg_loss,
                              train_disc_loss, train_pix_embedding, train_prediction, train_l2_reg_loss],
                             feed_dict={train_input_tensor: gt_imgs,
                                        train_binary_label_tensor: binary_gt_labels,
                                        train_instance_label_tensor: instance_gt_labels}
                             )
                # ============================== NaN check: stop if any loss blew up
                if math.isnan(train_loss) or math.isnan(
                        train_binary_loss) or math.isnan(train_instance_loss):
                    log.error('cost is: {:.5f}'.format(train_loss))
                    log.error(
                        'binary cost is: {:.5f}'.format(train_binary_loss))
                    log.error(
                        'instance cost is: {:.5f}'.format(train_instance_loss))
                    return
                # ==============================
                train_epoch_loss.append(train_loss)
                summary_writer.add_summary(summary=train_summary,
                                           global_step=global_step_val)
                pbar_train.set_description(
                    ("train loss: %.4f, learn rate: %e") % (train_loss, lr))
            train_cost_time = time.time() - train_t_start
            mean_train_loss = np.mean(train_epoch_loss)
            log.info(
                'MEAN Train: total_loss= {:6f} mean_cost_time= {:5f}s'.format(
                    mean_train_loss, train_cost_time))
            # ---------------------------------------------------------------- #

            # ================================================================ #
            #                                Val                               #
            # ================================================================ #
            # evaluate the full validation set once per epoch
            pbar_val = tqdm(valset)
            val_epoch_loss = []
            val_epoch_binary_loss = []
            val_epoch_instance_loss = []
            val_epoch_accuracy_figure = []
            val_epoch_fp_figure = []
            val_epoch_fn_figure = []
            val_t_start = time.time()
            for val_images, val_binary_labels, val_instance_labels in pbar_val:
                # validation part
                val_step += 1
                val_summary, \
                val_loss, val_binary_loss, val_instance_loss, \
                val_accuracy_figure, val_fn_figure, val_fp_figure = \
                    sess.run([val_merge_summary_op,
                              val_total_loss, val_binary_seg_loss, val_disc_loss,
                              val_accuracy, val_fn, val_fp],
                             feed_dict={val_input_tensor: val_images,
                                        val_binary_label_tensor: val_binary_labels,
                                        val_instance_label_tensor: val_instance_labels}
                             )
                # ============================== NaN check: stop if any loss blew up
                if math.isnan(val_loss) or math.isnan(
                        val_binary_loss) or math.isnan(val_instance_loss):
                    log.error('cost is: {:.5f}'.format(val_loss))
                    log.error('binary cost is: {:.5f}'.format(val_binary_loss))
                    log.error(
                        'instance cost is: {:.5f}'.format(val_instance_loss))
                    return
                # ==============================
                summary_writer.add_summary(summary=val_summary,
                                           global_step=val_step)
                pbar_val.set_description(("val loss: %.4f, accuracy: %.4f") %
                                         (val_loss, val_accuracy_figure))

                val_epoch_loss.append(val_loss)
                val_epoch_binary_loss.append(val_binary_loss)
                val_epoch_instance_loss.append(val_instance_loss)
                val_epoch_accuracy_figure.append(val_accuracy_figure)
                val_epoch_fp_figure.append(val_fp_figure)
                val_epoch_fn_figure.append(val_fn_figure)
            val_cost_time = time.time() - val_t_start
            mean_val_loss = np.mean(val_epoch_loss)
            mean_val_binary_loss = np.mean(val_epoch_binary_loss)
            mean_val_instance_loss = np.mean(val_epoch_instance_loss)
            mean_val_accuracy_figure = np.mean(val_epoch_accuracy_figure)
            mean_val_fp_figure = np.mean(val_epoch_fp_figure)
            mean_val_fn_figure = np.mean(val_epoch_fn_figure)

            # ==============================
            if mean_val_accuracy_figure > max_acc:
                max_acc = mean_val_accuracy_figure
                if save_num < 3:  # the first three improvements don't count
                    max_acc = 0.9
                log.info(
                    'MAX_ACC change to {}'.format(mean_val_accuracy_figure))
                model_save_path_max = ops.join(
                    model_save_dir, 'tusimple_lanenet_{}.ckpt'.format(
                        mean_val_accuracy_figure))
                saver.save(sess=sess,
                           save_path=model_save_path_max,
                           global_step=global_step)
                save_num += 1
            # ==============================

            log.info(
                '=> Epoch: {}, MEAN Val: total_loss= {:6f} binary_seg_loss= {:6f} '
                'instance_seg_loss= {:6f} accuracy= {:6f} fp= {:6f} fn= {:6f}'
                ' mean_cost_time= {:5f}s '.format(
                    epoch, mean_val_loss, mean_val_binary_loss,
                    mean_val_instance_loss, mean_val_accuracy_figure,
                    mean_val_fp_figure, mean_val_fn_figure, val_cost_time))
            # ---------------------------------------------------------------- #
    return
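
The schedule above relies on tf.train.polynomial_decay with cycle=True. As a sanity check on how the learning rate restarts, the documented decay formula can be mirrored in plain NumPy; the base learning rate and step count below are illustrative stand-ins for cfg.TRAIN.LEARNING_RATE and cfg.TRAIN.STEPS.

import numpy as np


def cyclic_polynomial_decay(step, base_lr, decay_steps, end_lr, power=0.9):
    """NumPy mirror of the tf.train.polynomial_decay formula with cycle=True."""
    # The decay period is stretched to the first multiple of decay_steps >= step.
    multiplier = max(1.0, np.ceil(step / decay_steps))
    effective_decay_steps = decay_steps * multiplier
    fraction = 1.0 - step / effective_decay_steps
    return (base_lr - end_lr) * fraction ** power + end_lr


# Illustrative numbers only; the real values come from cfg.TRAIN.*
base_lr, total_steps = 5e-4, 80000
for step in (0, 10000, 20000, 20001, 40000, 80000):
    lr = cyclic_polynomial_decay(step, base_lr, decay_steps=total_steps / 4,
                                 end_lr=base_lr / 10)
    print('step {:6d} -> lr {:.6f}'.format(step, lr))
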
Example #4
def train_lanenet(dataset_dir,
                  weights_path=None,
                  net_flag='vgg',
                  version_flag='',
                  scratch=False):
    """
    Train LaneNet With One GPU
    :param dataset_dir:
    :param weights_path:
    :param net_flag:
    :param version_flag:
    :param scratch:
    :return:
    """
    train_dataset = lanenet_data_feed_pipline.LaneNetDataFeeder(
        dataset_dir=dataset_dir, flags='train')
    val_dataset = lanenet_data_feed_pipline.LaneNetDataFeeder(
        dataset_dir=dataset_dir, flags='val')

    # ================================================================ #
    #                           Define Network                         #
    # ================================================================ #
    train_net = lanenet.LaneNet(net_flag=net_flag,
                                phase='train',
                                reuse=tf.AUTO_REUSE)
    val_net = lanenet.LaneNet(net_flag=net_flag, phase='val', reuse=True)
    # ---------------------------------------------------------------- #

    # ================================================================ #
    #                       Train Input & Output                       #
    # ================================================================ #
    # set compute graph node for training
    train_images, train_binary_labels, train_instance_labels = train_dataset.inputs(
        CFG.TRAIN.BATCH_SIZE)

    train_compute_ret = train_net.compute_loss(
        input_tensor=train_images,
        binary_label=train_binary_labels,
        instance_label=train_instance_labels,
        name='lanenet_model')
    train_total_loss = train_compute_ret['total_loss']
    train_binary_seg_loss = train_compute_ret['binary_seg_loss']  # semantic segmentation loss
    train_disc_loss = train_compute_ret[
        'discriminative_loss']  # embedding loss
    train_pix_embedding = train_compute_ret[
        'instance_seg_logits']  # embedding feature, HxWxN
    train_l2_reg_loss = train_compute_ret['l2_reg_loss']

    train_prediction_logits = train_compute_ret[
        'binary_seg_logits']  # semantic segmentation logits, HxWx2
    train_prediction_score = tf.nn.softmax(logits=train_prediction_logits)
    train_prediction = tf.argmax(train_prediction_score, axis=-1)  # binary segmentation map

    train_accuracy = evaluate_model_utils.calculate_model_precision(
        train_compute_ret['binary_seg_logits'], train_binary_labels)
    train_fp = evaluate_model_utils.calculate_model_fp(
        train_compute_ret['binary_seg_logits'], train_binary_labels)
    train_fn = evaluate_model_utils.calculate_model_fn(
        train_compute_ret['binary_seg_logits'], train_binary_labels)
    train_binary_seg_ret_for_summary = evaluate_model_utils.get_image_summary(
        img=train_prediction)  # (I - min) * 255 / (max - min), normalized to 0-255
    train_embedding_ret_for_summary = evaluate_model_utils.get_image_summary(
        img=train_pix_embedding)  # (I - min) * 255 / (max - min), normalized to 0-255
    # ---------------------------------------------------------------- #

    # ================================================================ #
    #                          Define Optimizer                        #
    # ================================================================ #
    # set optimizer
    global_step = tf.Variable(0, trainable=False, name='global_step')
    # learning_rate = tf.train.cosine_decay_restarts( # cosine decay with warm restarts
    #     learning_rate=CFG.TRAIN.LEARNING_RATE,      # initial learning rate
    #     global_step=global_step,                    # current iteration count
    #     first_decay_steps=CFG.TRAIN.STEPS/3,        # length of the first decay period
    #     t_mul=2.0,                                  # period multiplier after each restart
    #     m_mul=1.0,                                  # initial-learning-rate multiplier after each restart
    #     alpha=0.1,                                  # minimum learning rate = alpha * learning_rate
    # )
    learning_rate = tf.train.polynomial_decay(  # polynomial decay
        learning_rate=CFG.TRAIN.LEARNING_RATE,  # initial learning rate
        global_step=global_step,  # current iteration count
        decay_steps=CFG.TRAIN.STEPS / 4,  # steps per decay cycle; the rate decays towards end_learning_rate over this span
        end_learning_rate=CFG.TRAIN.LEARNING_RATE / 10,  # minimum learning rate
        power=0.9,
        cycle=True)
    learning_rate_scalar = tf.summary.scalar(name='learning_rate',
                                             tensor=learning_rate)
    update_ops = tf.get_collection(
        tf.GraphKeys.UPDATE_OPS)  # for batch normalization
    with tf.control_dependencies(update_ops):
        optimizer = tf.train.MomentumOptimizer(
            learning_rate=learning_rate, momentum=CFG.TRAIN.MOMENTUM).minimize(
                loss=train_total_loss,
                var_list=tf.trainable_variables(),
                global_step=global_step)
    # ---------------------------------------------------------------- #

    # ================================================================ #
    #                           Train Summary                          #
    # ================================================================ #
    train_cost_scalar = tf.summary.scalar(name='train_cost',
                                          tensor=train_total_loss)
    train_accuracy_scalar = tf.summary.scalar(name='train_accuracy',
                                              tensor=train_accuracy)
    train_binary_seg_loss_scalar = tf.summary.scalar(
        name='train_binary_seg_loss', tensor=train_binary_seg_loss)
    train_instance_seg_loss_scalar = tf.summary.scalar(
        name='train_instance_seg_loss', tensor=train_disc_loss)
    train_fn_scalar = tf.summary.scalar(name='train_fn', tensor=train_fn)
    train_fp_scalar = tf.summary.scalar(name='train_fp', tensor=train_fp)
    train_binary_seg_ret_img = tf.summary.image(
        name='train_binary_seg_ret', tensor=train_binary_seg_ret_for_summary)
    train_embedding_feats_ret_img = tf.summary.image(
        name='train_embedding_feats_ret',
        tensor=train_embedding_ret_for_summary)
    train_merge_summary_op = tf.summary.merge([
        train_accuracy_scalar, train_cost_scalar, train_binary_seg_loss_scalar,
        train_instance_seg_loss_scalar, train_fn_scalar, train_fp_scalar,
        train_binary_seg_ret_img, train_embedding_feats_ret_img,
        learning_rate_scalar
    ])
    # ---------------------------------------------------------------- #

    # ================================================================ #
    #                        Val Input & Output                        #
    # ================================================================ #
    # set compute graph node for validation
    val_images, val_binary_labels, val_instance_labels = val_dataset.inputs(
        CFG.TEST.BATCH_SIZE)

    val_compute_ret = val_net.compute_loss(input_tensor=val_images,
                                           binary_label=val_binary_labels,
                                           instance_label=val_instance_labels,
                                           name='lanenet_model')
    val_total_loss = val_compute_ret['total_loss']
    val_binary_seg_loss = val_compute_ret['binary_seg_loss']
    val_disc_loss = val_compute_ret['discriminative_loss']
    val_pix_embedding = val_compute_ret['instance_seg_logits']

    val_prediction_logits = val_compute_ret['binary_seg_logits']
    val_prediction_score = tf.nn.softmax(logits=val_prediction_logits)
    val_prediction = tf.argmax(val_prediction_score, axis=-1)

    val_accuracy = evaluate_model_utils.calculate_model_precision(
        val_compute_ret['binary_seg_logits'], val_binary_labels)
    val_fp = evaluate_model_utils.calculate_model_fp(
        val_compute_ret['binary_seg_logits'], val_binary_labels)
    val_fn = evaluate_model_utils.calculate_model_fn(
        val_compute_ret['binary_seg_logits'], val_binary_labels)
    val_binary_seg_ret_for_summary = evaluate_model_utils.get_image_summary(
        img=val_prediction)
    val_embedding_ret_for_summary = evaluate_model_utils.get_image_summary(
        img=val_pix_embedding)
    # ---------------------------------------------------------------- #

    # ================================================================ #
    #                            VAL Summary                           #
    # ================================================================ #
    val_cost_scalar = tf.summary.scalar(name='val_cost', tensor=val_total_loss)
    val_accuracy_scalar = tf.summary.scalar(name='val_accuracy',
                                            tensor=val_accuracy)
    val_binary_seg_loss_scalar = tf.summary.scalar(name='val_binary_seg_loss',
                                                   tensor=val_binary_seg_loss)
    val_instance_seg_loss_scalar = tf.summary.scalar(
        name='val_instance_seg_loss', tensor=val_disc_loss)
    val_fn_scalar = tf.summary.scalar(name='val_fn', tensor=val_fn)
    val_fp_scalar = tf.summary.scalar(name='val_fp', tensor=val_fp)
    val_binary_seg_ret_img = tf.summary.image(
        name='val_binary_seg_ret', tensor=val_binary_seg_ret_for_summary)
    val_embedding_feats_ret_img = tf.summary.image(
        name='val_embedding_feats_ret', tensor=val_embedding_ret_for_summary)
    val_merge_summary_op = tf.summary.merge([
        val_accuracy_scalar, val_cost_scalar, val_binary_seg_loss_scalar,
        val_instance_seg_loss_scalar, val_fn_scalar, val_fp_scalar,
        val_binary_seg_ret_img, val_embedding_feats_ret_img
    ])
    # ---------------------------------------------------------------- #

    # ================================================================ #
    #                      Config Saver & Session                      #
    # ================================================================ #
    # Set tf model save path
    model_save_dir = 'model/tusimple_lanenet_{:s}_{:s}'.format(
        net_flag, version_flag)
    os.makedirs(model_save_dir, exist_ok=True)
    train_start_time = time.strftime('%Y-%m-%d-%H-%M-%S',
                                     time.localtime(time.time()))
    model_name = 'tusimple_lanenet_{:s}_{:s}.ckpt'.format(
        net_flag, str(train_start_time))
    model_save_path = ops.join(model_save_dir, model_name)

    # ==============================
    if scratch:
        """
        Drop the Momentum slot variables (they are also dropped from the saved meta file).
        When TensorFlow saves a model with the global_step option, the global_step value is saved too,
        so restoring would continue training from that global_step; it therefore has to be removed.
        """
        variables = tf.contrib.framework.get_variables_to_restore()
        variables_to_restore = [
            v for v in variables if 'Momentum' not in v.name.split('/')[-1]
        ]
        variables_to_restore = [
            v for v in variables_to_restore
            if 'global_step' not in v.name.split('/')[-1]
        ]
        restore_saver = tf.train.Saver(variables_to_restore)
    else:
        restore_saver = tf.train.Saver()
    saver = tf.train.Saver(max_to_keep=10)
    # ==============================

    # Set tf summary save path
    tboard_save_path = 'tboard/tusimple_lanenet_{:s}_{:s}'.format(
        net_flag, version_flag)
    os.makedirs(tboard_save_path, exist_ok=True)

    # Set sess configuration
    # ============================== config GPU
    sess_config = tf.ConfigProto(allow_soft_placement=True)
    sess_config.gpu_options.per_process_gpu_memory_fraction = CFG.TRAIN.GPU_MEMORY_FRACTION
    sess_config.gpu_options.allow_growth = CFG.TRAIN.TF_ALLOW_GROWTH
    sess_config.gpu_options.allocator_type = 'BFC'
    # ==============================
    sess = tf.Session(config=sess_config)

    summary_writer = tf.summary.FileWriter(tboard_save_path)
    summary_writer.add_graph(sess.graph)
    # ---------------------------------------------------------------- #

    # Set the training parameters
    import math
    train_steps = CFG.TRAIN.STEPS
    val_steps = math.ceil(CFG.TRAIN.VAL_SIZE /
                          CFG.TEST.BATCH_SIZE)  # number of batches to cover the validation set once
    one_epoch2step = math.ceil(CFG.TRAIN.TRAIN_SIZE /
                               CFG.TRAIN.BATCH_SIZE)  # number of batches per training epoch

    log.info('Global configuration is as follows:')
    log.info(CFG)
    max_acc = 0.9
    save_num = 0
    # ================================================================ #
    #                            Train & Val                           #
    # ================================================================ #
    with sess.as_default():
        # ============================== load pretrain model
        if weights_path is None:
            log.info('Training from scratch')
            sess.run(tf.global_variables_initializer())
            if net_flag == 'vgg':
                # seed the vgg backbone with the pretrained vgg16 weights
                load_pretrained_weights(tf.trainable_variables(),
                                        './data/vgg16.npy', sess)
        elif scratch:  # start training anew from the given weights, similar to Caffe's --weights
            sess.run(tf.global_variables_initializer())
            log.info('Restore model from last model checkpoint {:s}, scratch'.
                     format(weights_path))
            try:
                restore_saver.restore(sess=sess, save_path=weights_path)
            except Exception:
                log.info('checkpoint may not exist!')
        else:  # 继续训练,类似 Caffe 的 --snapshot
            log.info('Restore model from last model checkpoint {:s}'.format(
                weights_path))
            try:
                restore_saver.restore(sess=sess, save_path=weights_path)
            except Exception:
                log.info('checkpoint may not exist!')
        # ==============================

        train_cost_time_mean = []  # track per-batch training time
        for step in range(train_steps):
            # ================================================================ #
            #                               Train                              #
            # ================================================================ #
            t_start = time.time()

            _, train_loss, train_accuracy_figure, train_fn_figure, train_fp_figure, \
                lr, train_summary, train_binary_loss, \
                train_instance_loss, train_embeddings, train_binary_seg_imgs, train_gt_imgs, \
                train_binary_gt_labels, train_instance_gt_labels, train_l2_loss = \
                sess.run([optimizer, train_total_loss, train_accuracy, train_fn, train_fp,
                          learning_rate, train_merge_summary_op, train_binary_seg_loss,
                          train_disc_loss, train_pix_embedding, train_prediction,
                          train_images, train_binary_labels, train_instance_labels, train_l2_reg_loss])

            cost_time = time.time() - t_start
            train_cost_time_mean.append(cost_time)
            # ============================== NaN check: stop if any loss blew up
            if math.isnan(train_loss) or math.isnan(
                    train_binary_loss) or math.isnan(train_instance_loss):
                log.error('cost is: {:.5f}'.format(train_loss))
                log.error('binary cost is: {:.5f}'.format(train_binary_loss))
                log.error(
                    'instance cost is: {:.5f}'.format(train_instance_loss))
                return
            # ==============================
            summary_writer.add_summary(summary=train_summary, global_step=step)

            # print the losses every DISPLAY_STEP steps
            if step % CFG.TRAIN.DISPLAY_STEP == 0:
                epoch_num = step // one_epoch2step
                log.info(
                    'Epoch: {:d} Step: {:d} total_loss= {:6f} binary_seg_loss= {:6f} '
                    'instance_seg_loss= {:6f} l2_reg_loss= {:6f} accuracy= {:6f} fp= {:6f} fn= {:6f}'
                    ' lr= {:6f} mean_cost_time= {:5f}s '.format(
                        epoch_num + 1, step + 1, train_loss, train_binary_loss,
                        train_instance_loss, train_l2_loss,
                        train_accuracy_figure, train_fp_figure,
                        train_fn_figure, lr, np.mean(train_cost_time_mean)))
                train_cost_time_mean.clear()
            # # every VAL_DISPLAY_STEP steps, save the model and the current batch's training result images
            # if step % CFG.TRAIN.VAL_DISPLAY_STEP == 0:
            #     saver.save(sess=sess, save_path=model_save_path, global_step=global_step) # passing global_step records the global_step value in the checkpoint
            #     record_training_intermediate_result(
            #         gt_images=train_gt_imgs, gt_binary_labels=train_binary_gt_labels,
            #         gt_instance_labels=train_instance_gt_labels, binary_seg_images=train_binary_seg_imgs,
            #         pix_embeddings=train_embeddings
            #     )
            # ---------------------------------------------------------------- #

            # ================================================================ #
            #                                Val                               #
            # ================================================================ #
            # every VAL_DISPLAY_STEP steps, evaluate the full validation set
            if step % CFG.TRAIN.VAL_DISPLAY_STEP == 0:
                val_t_start = time.time()
                val_cost_time = 0
                mean_val_c = 0.0
                mean_val_binary_loss = 0.0
                mean_val_instance_loss = 0.0
                mean_val_accuracy_figure = 0.0
                mean_val_fp_figure = 0.0
                mean_val_fn_figure = 0.0
                for val_step in range(val_steps):
                    # validation part
                    val_c, val_accuracy_figure, val_fn_figure, val_fp_figure, \
                        val_summary, val_binary_loss, val_instance_loss, \
                        val_embeddings, val_binary_seg_imgs, val_gt_imgs, \
                        val_binary_gt_labels, val_instance_gt_labels = \
                        sess.run([val_total_loss, val_accuracy, val_fn, val_fp,
                                  val_merge_summary_op, val_binary_seg_loss,
                                  val_disc_loss, val_pix_embedding, val_prediction,
                                  val_images, val_binary_labels, val_instance_labels])

                    # ============================== NaN check: stop if any loss blew up
                    if math.isnan(val_c) or math.isnan(
                            val_binary_loss) or math.isnan(val_instance_loss):
                        log.error('cost is: {:.5f}'.format(val_c))
                        log.error(
                            'binary cost is: {:.5f}'.format(val_binary_loss))
                        log.error('instance cost is: {:.5f}'.format(
                            val_instance_loss))
                        return
                    # ==============================

                    # if val_step == 0:
                    #     record_training_intermediate_result(
                    #         gt_images=val_gt_imgs, gt_binary_labels=val_binary_gt_labels,
                    #         gt_instance_labels=val_instance_gt_labels, binary_seg_images=val_binary_seg_imgs,
                    #         pix_embeddings=val_embeddings, flag='val'
                    #     )

                    cost_time = time.time() - val_t_start
                    val_t_start = time.time()  # reset so each iteration adds only its own duration
                    val_cost_time += cost_time
                    mean_val_c += val_c
                    mean_val_binary_loss += val_binary_loss
                    mean_val_instance_loss += val_instance_loss
                    mean_val_accuracy_figure += val_accuracy_figure
                    mean_val_fp_figure += val_fp_figure
                    mean_val_fn_figure += val_fn_figure
                    summary_writer.add_summary(summary=val_summary,
                                               global_step=step)

                mean_val_c /= val_steps
                mean_val_binary_loss /= val_steps
                mean_val_instance_loss /= val_steps
                mean_val_accuracy_figure /= val_steps
                mean_val_fp_figure /= val_steps
                mean_val_fn_figure /= val_steps

                # ==============================
                if mean_val_accuracy_figure > max_acc:
                    max_acc = mean_val_accuracy_figure
                    if save_num < 3:  # the first three improvements don't count
                        max_acc = 0.9
                    log.info('MAX_ACC change to {}'.format(
                        mean_val_accuracy_figure))
                    model_save_path_max = ops.join(
                        model_save_dir, 'tusimple_lanenet_{}.ckpt'.format(
                            mean_val_accuracy_figure))
                    saver.save(sess=sess,
                               save_path=model_save_path_max,
                               global_step=global_step)
                    save_num += 1
                # ==============================

                log.info(
                    'MEAN Val: total_loss= {:6f} binary_seg_loss= {:6f} '
                    'instance_seg_loss= {:6f} accuracy= {:6f} fp= {:6f} fn= {:6f}'
                    ' mean_cost_time= {:5f}s '.format(
                        mean_val_c, mean_val_binary_loss,
                        mean_val_instance_loss, mean_val_accuracy_figure,
                        mean_val_fp_figure, mean_val_fn_figure, val_cost_time))

            # ---------------------------------------------------------------- #
    return
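
To resume from whatever the run above last saved, one option is to look up the newest checkpoint in the save directory and pass it back in as weights_path. This is only a sketch that assumes train_lanenet from this example is in scope; the directory name assumes net_flag='vgg' with an empty version_flag, and the dataset path is a placeholder.

import tensorflow as tf

model_save_dir = 'model/tusimple_lanenet_vgg_'  # matches the naming scheme used above
latest_ckpt = tf.train.latest_checkpoint(model_save_dir)  # None if nothing was saved yet, which falls back to training from scratch
train_lanenet(dataset_dir='/path/to/tusimple_dataset',  # placeholder path
              weights_path=latest_ckpt,
              net_flag='vgg')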