Example #1
def test_lanenet_batch(image_list, weights_path, batch_size, use_gpu, net_flag='vgg'):
    """

    :param image_list:
    :param weights_path:
    :param batch_size:
    :param use_gpu:
    :param net_flag:
    :return:
    """
    assert ops.exists(image_list), '{:s} does not exist'.format(image_list)

    log.info('Start loading the dataset list...')
    test_dataset = lanenet_data_processor.DataSet(image_list, traing=False)

    # ==============================
    gt_label_binary_list = []
    with open(image_list, 'r') as file:
        for _info in file:
            info_tmp = _info.strip(' ').split()
            gt_label_binary_list.append(info_tmp[1])
    # ==============================

    input_tensor = tf.placeholder(dtype=tf.float32, shape=[None, 256, 512, 3], name='input_tensor')
    binary_label_tensor = tf.placeholder(dtype=tf.int64,
                                         shape=[None, 256, 512, 1], name='binary_input_label')
    phase_tensor = tf.constant('test', tf.string)
    net = lanenet.LaneNet(phase=phase_tensor, net_flag=net_flag)
    binary_seg_ret, instance_seg_ret, recall_ret, false_positive, false_negative, precision_ret, accuracy_ret = \
        net.compute_acc(input_tensor=input_tensor, binary_label_tensor=binary_label_tensor, name='lanenet_model')

    saver = tf.train.Saver()
    # ==============================
    # Set sess configuration
    if use_gpu:
        sess_config = tf.ConfigProto(device_count={'GPU': 1})
    else:
        sess_config = tf.ConfigProto(device_count={'GPU': 0})
    sess_config.gpu_options.per_process_gpu_memory_fraction = CFG.TEST.GPU_MEMORY_FRACTION
    sess_config.gpu_options.allow_growth = CFG.TRAIN.TF_ALLOW_GROWTH
    sess_config.gpu_options.allocator_type = 'BFC'
    # ==============================
    sess = tf.Session(config=sess_config)

    with sess.as_default():
        saver.restore(sess=sess, save_path=weights_path)
        epoch_nums = int(math.ceil(test_dataset._dataset_size / batch_size))
        mean_accuracy = 0.0
        mean_recall = 0.0
        mean_precision = 0.0
        mean_fp = 0.0
        mean_fn = 0.0
        total_num = 0
        t_start = time.time()
        for epoch in range(epoch_nums):
            gt_imgs, binary_gt_labels, instance_gt_labels = test_dataset.next_batch(batch_size)
            if net_flag == 'vgg':
                image_list_epoch = [tmp / 127.5 - 1.0 for tmp in gt_imgs]
            elif net_flag == 'mobilenet_v2':
                image_list_epoch = [tmp - [103.939, 116.779, 123.68] for tmp in gt_imgs]
            else:
                raise ValueError('unsupported net_flag: {:s}'.format(net_flag))

            binary_seg_images, instance_seg_images, recall, fp, fn, precision, accuracy = sess.run(
                [binary_seg_ret, instance_seg_ret, recall_ret, false_positive, false_negative, precision_ret, accuracy_ret],
                feed_dict={input_tensor: image_list_epoch, binary_label_tensor: binary_gt_labels})
            # ==============================
            out_dir = 'H:/Other_DataSets/TuSimple/out/'
            # Only the first image of each batch is written out
            dst_binary_image_path = ops.join(out_dir, gt_label_binary_list[epoch])
            root_dir = ops.dirname(ops.abspath(dst_binary_image_path))
            if not ops.exists(root_dir):
                os.makedirs(root_dir)
            cv2.imwrite(dst_binary_image_path, binary_seg_images[0] * 255)
            # ==============================
            print(recall, fp, fn)
            mean_accuracy += accuracy
            mean_precision += precision
            mean_recall += recall
            mean_fp += fp
            mean_fn += fn
            total_num += len(gt_imgs)
        t_cost = time.time() - t_start
        mean_accuracy = mean_accuracy / epoch_nums
        mean_precision = mean_precision / epoch_nums
        mean_recall = mean_recall / epoch_nums
        mean_fp = mean_fp / epoch_nums
        mean_fn = mean_fn / epoch_nums
        print('Tested {} images in {}s: {}_recall = {}, precision = {}, accuracy = {}, fp = {}, fn = {}'.format(
            total_num, t_cost, net_flag, mean_recall, mean_precision, mean_accuracy, mean_fp, mean_fn))

    sess.close()
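
A minimal driver for the function above might look like the following; the paths, batch size, and flags are placeholder values rather than ones taken from the original project:

if __name__ == '__main__':
    # Hypothetical invocation; point these at your own dataset list and checkpoint
    test_lanenet_batch(image_list='./data/test.txt',
                       weights_path='./model/lanenet.ckpt',
                       batch_size=8,
                       use_gpu=True,
                       net_flag='vgg')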
Example #2
def train_net(dataset_dir, weights_path=None, net_flag='vgg'):
    """

    :param dataset_dir:
    :param net_flag: choose which base network to use
    :param weights_path:
    :return:
    """
    train_dataset_file = ops.join(dataset_dir, 'train.txt')
    val_dataset_file = ops.join(dataset_dir, 'val.txt')

    assert ops.exists(train_dataset_file)

    train_dataset = lanenet_data_processor.DataSet(train_dataset_file)
    val_dataset = lanenet_data_processor.DataSet(val_dataset_file)

    input_tensor = tf.placeholder(
        dtype=tf.float32,
        shape=[None, CFG.TRAIN.IMG_HEIGHT, CFG.TRAIN.IMG_WIDTH, 3],
        name='input_tensor')
    binary_label_tensor = tf.placeholder(
        dtype=tf.int64,
        shape=[None, CFG.TRAIN.IMG_HEIGHT, CFG.TRAIN.IMG_WIDTH, 1],
        name='binary_input_label')
    instance_label_tensor = tf.placeholder(
        dtype=tf.float32,
        shape=[None, CFG.TRAIN.IMG_HEIGHT, CFG.TRAIN.IMG_WIDTH],
        name='instance_input_label')
    phase = tf.placeholder(dtype=tf.string, shape=None, name='net_phase')

    # net = lanenet_instance_segmentation.LaneNetInstanceSeg(net_flag=net_flag, phase=phase)
    net = lanenet_merge_model.LaneNet(net_flag=net_flag, phase=phase)

    # calculate the loss
    compute_ret = net.compute_loss(input_tensor=input_tensor,
                                   binary_label=binary_label_tensor,
                                   instance_label=instance_label_tensor,
                                   name='lanenet_loss')
    total_loss = compute_ret['total_loss']
    binary_seg_loss = compute_ret['binary_seg_loss']
    disc_loss = compute_ret['discriminative_loss']
    pix_embedding = compute_ret['instance_seg_logits']

    # calculate the accuracy: 1 - (#mismatched pixels) / (#pixels per image),
    # averaged over the batch
    out_logits = compute_ret['binary_seg_logits']
    out_logits = tf.nn.softmax(logits=out_logits)
    out_logits_out = tf.argmax(out_logits, axis=-1)
    out = tf.expand_dims(out_logits_out, axis=-1)
    accuracy = tf.add(binary_label_tensor, -1 * out)
    accuracy = tf.count_nonzero(accuracy, axis=[1, 2, 3])
    accuracy = tf.add(
        tf.constant(1, dtype=tf.float64),
        -1 * tf.divide(accuracy, CFG.TRAIN.IMG_HEIGHT * CFG.TRAIN.IMG_WIDTH))
    accuracy = tf.reduce_mean(accuracy, axis=0)

    global_step = tf.Variable(0, trainable=False)
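    # Staircase schedule below: lr = LEARNING_RATE * 0.96 ** floor(global_step / 5000)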
    learning_rate = tf.train.exponential_decay(CFG.TRAIN.LEARNING_RATE,
                                               global_step,
                                               5000,
                                               0.96,
                                               staircase=True)
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        optimizer = tf.train.AdamOptimizer(
            learning_rate=learning_rate).minimize(
                loss=total_loss,
                var_list=tf.trainable_variables(),
                global_step=global_step)

    # Set tf saver
    saver = tf.train.Saver()
    model_save_dir = 'model/kitti_lanenet'
    if not ops.exists(model_save_dir):
        os.makedirs(model_save_dir)
    train_start_time = time.strftime('%Y-%m-%d-%H-%M-%S',
                                     time.localtime(time.time()))
    model_name = 'kitti_lanenet_{:s}_{:s}.ckpt'.format(net_flag,
                                                       str(train_start_time))
    model_save_path = ops.join(model_save_dir, model_name)

    # Set tf summary
    tboard_save_path = 'tboard/kitti_lanenet/{:s}'.format(net_flag)
    if not ops.exists(tboard_save_path):
        os.makedirs(tboard_save_path)
    train_cost_scalar = tf.summary.scalar(name='train_cost', tensor=total_loss)
    val_cost_scalar = tf.summary.scalar(name='val_cost', tensor=total_loss)
    train_accuracy_scalar = tf.summary.scalar(name='train_accuracy',
                                              tensor=accuracy)
    val_accuracy_scalar = tf.summary.scalar(name='val_accuracy',
                                            tensor=accuracy)
    train_binary_seg_loss_scalar = tf.summary.scalar(
        name='train_binary_seg_loss', tensor=binary_seg_loss)
    val_binary_seg_loss_scalar = tf.summary.scalar(name='val_binary_seg_loss',
                                                   tensor=binary_seg_loss)
    train_instance_seg_loss_scalar = tf.summary.scalar(
        name='train_instance_seg_loss', tensor=disc_loss)
    val_instance_seg_loss_scalar = tf.summary.scalar(
        name='val_instance_seg_loss', tensor=disc_loss)
    learning_rate_scalar = tf.summary.scalar(name='learning_rate',
                                             tensor=learning_rate)
    train_merge_summary_op = tf.summary.merge([
        train_accuracy_scalar, train_cost_scalar, learning_rate_scalar,
        train_binary_seg_loss_scalar, train_instance_seg_loss_scalar
    ])
    val_merge_summary_op = tf.summary.merge([
        val_accuracy_scalar, val_cost_scalar, val_binary_seg_loss_scalar,
        val_instance_seg_loss_scalar
    ])

    # Set sess configuration
    sess_config = tf.ConfigProto(device_count={'GPU': 1})
    sess_config.gpu_options.per_process_gpu_memory_fraction = CFG.TRAIN.GPU_MEMORY_FRACTION
    sess_config.gpu_options.allow_growth = CFG.TRAIN.TF_ALLOW_GROWTH
    sess_config.gpu_options.allocator_type = 'BFC'

    sess = tf.Session(config=sess_config)

    summary_writer = tf.summary.FileWriter(tboard_save_path)
    summary_writer.add_graph(sess.graph)

    # Set the training parameters
    train_epochs = CFG.TRAIN.EPOCHS

    log.info('Global configuration is as follows:')
    log.info(CFG)

    with sess.as_default():

        tf.train.write_graph(
            graph_or_graph_def=sess.graph,
            logdir='',
            name='{:s}/lanenet_model.pb'.format(model_save_dir))

        if weights_path is None:
            log.info('Training from scratch')
            init = tf.global_variables_initializer()
            sess.run(init)
        else:
            log.info('Restore model from last model checkpoint {:s}'.format(
                weights_path))
            saver.restore(sess=sess, save_path=weights_path)

        # Load the pre-trained VGG weights
        if net_flag == 'vgg':
            pretrained_weights = np.load(
                '/home/baidu/Silly_Project/ICode/baidu/beec/semantic-road-estimation/data/vgg16.npy',
                encoding='latin1').item()

            for vv in tf.trainable_variables():
                weights_key = vv.name.split('/')[-3]
                try:
                    weights = pretrained_weights[weights_key][0]
                    _op = tf.assign(vv, weights)
                    sess.run(_op)
                except Exception:
                    # skip variables that have no matching pretrained weights
                    continue

        train_cost_time_mean = []
        val_cost_time_mean = []
        for epoch in range(train_epochs):
            # training part
            t_start = time.time()

            gt_imgs, binary_gt_labels, instance_gt_labels = train_dataset.next_batch(
                CFG.TRAIN.BATCH_SIZE)
            gt_imgs = [
                cv2.resize(tmp,
                           dsize=(CFG.TRAIN.IMG_WIDTH, CFG.TRAIN.IMG_HEIGHT),
                           dst=tmp,
                           interpolation=cv2.INTER_LINEAR) for tmp in gt_imgs
            ]
            gt_imgs = [tmp - VGG_MEAN for tmp in gt_imgs]
            binary_gt_labels = [
                cv2.resize(tmp,
                           dsize=(CFG.TRAIN.IMG_WIDTH, CFG.TRAIN.IMG_HEIGHT),
                           dst=tmp,
                           interpolation=cv2.INTER_NEAREST)
                for tmp in binary_gt_labels
            ]
            binary_gt_labels = [
                np.expand_dims(tmp, axis=-1) for tmp in binary_gt_labels
            ]
            instance_gt_labels = [
                cv2.resize(tmp,
                           dsize=(CFG.TRAIN.IMG_WIDTH, CFG.TRAIN.IMG_HEIGHT),
                           dst=tmp,
                           interpolation=cv2.INTER_NEAREST)
                for tmp in instance_gt_labels
            ]
            phase_train = 'train'

            _, c, train_accuracy, train_summary, binary_loss, instance_loss, embedding, binary_seg_img = \
                sess.run([optimizer, total_loss,
                          accuracy,
                          train_merge_summary_op,
                          binary_seg_loss,
                          disc_loss,
                          pix_embedding,
                          out_logits_out],
                         feed_dict={input_tensor: gt_imgs,
                                    binary_label_tensor: binary_gt_labels,
                                    instance_label_tensor: instance_gt_labels,
                                    phase: phase_train})

            if math.isnan(c) or math.isnan(binary_loss) or math.isnan(
                    instance_loss):
                log.error('cost is: {:.5f}'.format(c))
                log.error('binary cost is: {:.5f}'.format(binary_loss))
                log.error('instance cost is: {:.5f}'.format(instance_loss))
                cv2.imwrite('nan_image.png', gt_imgs[0] + VGG_MEAN)
                cv2.imwrite('nan_instance_label.png', instance_gt_labels[0])
                cv2.imwrite('nan_binary_label.png', binary_gt_labels[0] * 255)
                cv2.imwrite('nan_embedding.png', embedding[0])
                return
            if epoch % 100 == 0:
                cv2.imwrite('image.png', gt_imgs[0] + VGG_MEAN)
                cv2.imwrite('binary_label.png', binary_gt_labels[0] * 255)
                cv2.imwrite('instance_label.png', instance_gt_labels[0])
                cv2.imwrite('binary_seg_img.png', binary_seg_img[0] * 255)
                cv2.imwrite('embedding.png', embedding[0])

            cost_time = time.time() - t_start
            train_cost_time_mean.append(cost_time)
            summary_writer.add_summary(summary=train_summary,
                                       global_step=epoch)

            # validation part
            gt_imgs_val, binary_gt_labels_val, instance_gt_labels_val \
                = val_dataset.next_batch(CFG.TRAIN.VAL_BATCH_SIZE)
            gt_imgs_val = [
                cv2.resize(tmp,
                           dsize=(CFG.TRAIN.IMG_WIDTH, CFG.TRAIN.IMG_HEIGHT),
                           dst=tmp,
                           interpolation=cv2.INTER_LINEAR)
                for tmp in gt_imgs_val
            ]
            gt_imgs_val = [tmp - VGG_MEAN for tmp in gt_imgs_val]
            binary_gt_labels_val = [
                cv2.resize(tmp,
                           dsize=(CFG.TRAIN.IMG_WIDTH, CFG.TRAIN.IMG_HEIGHT),
                           dst=tmp,
                           interpolation=cv2.INTER_NEAREST)  # nearest keeps label values valid
                for tmp in binary_gt_labels_val
            ]
            binary_gt_labels_val = [
                np.expand_dims(tmp, axis=-1) for tmp in binary_gt_labels_val
            ]
            instance_gt_labels_val = [
                cv2.resize(tmp,
                           dsize=(CFG.TRAIN.IMG_WIDTH, CFG.TRAIN.IMG_HEIGHT),
                           dst=tmp,
                           interpolation=cv2.INTER_NEAREST)
                for tmp in instance_gt_labels_val
            ]
            phase_val = 'test'

            t_start_val = time.time()
            c_val, val_summary, val_accuracy, val_binary_seg_loss, val_instance_seg_loss = \
                sess.run([total_loss, val_merge_summary_op, accuracy, binary_seg_loss, disc_loss],
                         feed_dict={input_tensor: gt_imgs_val,
                                    binary_label_tensor: binary_gt_labels_val,
                                    instance_label_tensor: instance_gt_labels_val,
                                    phase: phase_val})

            if epoch % 100 == 0:
                cv2.imwrite('test_image.png', gt_imgs_val[0] + VGG_MEAN)

            summary_writer.add_summary(val_summary, global_step=epoch)

            cost_time_val = time.time() - t_start_val
            val_cost_time_mean.append(cost_time_val)

            if epoch % CFG.TRAIN.DISPLAY_STEP == 0:
                log.info(
                    'Epoch: {:d} total_loss= {:6f} binary_seg_loss= {:6f} instance_seg_loss= {:6f} accuracy= {:6f}'
                    ' mean_cost_time= {:5f}s '.format(
                        epoch + 1, c, binary_loss, instance_loss,
                        train_accuracy, np.mean(train_cost_time_mean)))
                train_cost_time_mean.clear()

            if epoch % CFG.TRAIN.TEST_DISPLAY_STEP == 0:
                log.info(
                    'Epoch_Val: {:d} total_loss= {:6f} binary_seg_loss= {:6f} '
                    'instance_seg_loss= {:6f} accuracy= {:6f} '
                    'mean_cost_time= {:5f}s '.format(
                        epoch + 1, c_val, val_binary_seg_loss,
                        val_instance_seg_loss, val_accuracy,
                        np.mean(val_cost_time_mean)))
                val_cost_time_mean.clear()

            if epoch % 2000 == 0:
                saver.save(sess=sess,
                           save_path=model_save_path,
                           global_step=epoch)
    sess.close()

    return
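
The accuracy tensor assembled above is plain pixel accuracy in graph form: one minus the fraction of mismatched pixels per image, averaged over the batch. A NumPy sketch of the same formula, for illustration only:

import numpy as np

def pixel_accuracy(binary_label, binary_pred):
    """Batch pixel accuracy: 1 - (#mismatched pixels) / (H * W), averaged over the batch."""
    mismatch = np.count_nonzero(binary_label - binary_pred, axis=(1, 2, 3))
    height, width = binary_label.shape[1:3]
    return np.mean(1.0 - mismatch / float(height * width))

# Two 4x4 single-channel masks; one pixel wrong in the first image
label = np.zeros((2, 4, 4, 1), dtype=np.int64)
pred = label.copy()
pred[0, 0, 0, 0] = 1
print(pixel_accuracy(label, pred))  # (15/16 + 16/16) / 2 = 0.96875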
Example #3
def train_net(dataset_dir, weights_path=None, net_flag='vgg'):
    """

    :param dataset_dir:
    :param net_flag: choose which base network to use
    :param weights_path:
    :return:
    """
    train_dataset_file = ops.join(dataset_dir, 'train_gt.txt')
    val_dataset_file = ops.join(dataset_dir, 'val_gt.txt')

    assert ops.exists(train_dataset_file)

    train_dataset = lanenet_data_processor.DataSet(train_dataset_file)
    val_dataset = lanenet_data_processor.DataSet(val_dataset_file)

    input_tensor = tf.placeholder(dtype=tf.float32,
                                  shape=[
                                      CFG.TRAIN.BATCH_SIZE,
                                      CFG.TRAIN.IMG_HEIGHT,
                                      CFG.TRAIN.IMG_WIDTH, 3
                                  ],
                                  name='input_tensor')
    instance_label_tensor = tf.placeholder(dtype=tf.int64,
                                           shape=[
                                               CFG.TRAIN.BATCH_SIZE,
                                               CFG.TRAIN.IMG_HEIGHT,
                                               CFG.TRAIN.IMG_WIDTH
                                           ],
                                           name='instance_input_label')
    existence_label_tensor = tf.placeholder(dtype=tf.float32,
                                            shape=[CFG.TRAIN.BATCH_SIZE, 4],
                                            name='existence_input_label')
    phase = tf.placeholder(dtype=tf.string, shape=None, name='net_phase')

    net = lanenet_merge_model.LaneNet(net_flag=net_flag, phase=phase)

    # calculate the loss
    compute_ret = net.compute_loss(input_tensor=input_tensor,
                                   binary_label=instance_label_tensor,
                                   existence_label=existence_label_tensor,
                                   name='lanenet_loss')
    total_loss = compute_ret['total_loss']
    instance_loss = compute_ret['instance_seg_loss']
    existence_loss = compute_ret['existence_pre_loss']
    existence_logits = compute_ret['existence_logits']

    # calculate the accuracy
    out_logits = compute_ret['instance_seg_logits']
    out_logits_ref = out_logits
    out_logits = tf.nn.softmax(logits=out_logits)
    out_logits_out = tf.argmax(out_logits, axis=-1)  # 8 x 288 x 800

    # Per-class true-positive pixel counts (class 0 = background, 1-4 = lanes)
    pred_0 = tf.count_nonzero(
        tf.multiply(tf.cast(tf.equal(instance_label_tensor, 0), tf.int64),
                    tf.cast(tf.equal(out_logits_out, 0), tf.int64)))

    pred_1 = tf.count_nonzero(
        tf.multiply(tf.cast(tf.equal(instance_label_tensor, 1), tf.int64),
                    tf.cast(tf.equal(out_logits_out, 1), tf.int64)))
    pred_2 = tf.count_nonzero(
        tf.multiply(tf.cast(tf.equal(instance_label_tensor, 2), tf.int64),
                    tf.cast(tf.equal(out_logits_out, 2), tf.int64)))
    pred_3 = tf.count_nonzero(
        tf.multiply(tf.cast(tf.equal(instance_label_tensor, 3), tf.int64),
                    tf.cast(tf.equal(out_logits_out, 3), tf.int64)))
    pred_4 = tf.count_nonzero(
        tf.multiply(tf.cast(tf.equal(instance_label_tensor, 4), tf.int64),
                    tf.cast(tf.equal(out_logits_out, 4), tf.int64)))
    gt_all = tf.count_nonzero(
        tf.cast(tf.greater(instance_label_tensor, 0), tf.int64))
    gt_back = tf.count_nonzero(
        tf.cast(tf.equal(instance_label_tensor, 0), tf.int64))

    pred_all = tf.add(tf.add(tf.add(pred_1, pred_2), pred_3), pred_4)

    # Recall over lane pixels and over background pixels, respectively
    accuracy = tf.divide(pred_all, gt_all)
    accuracy_back = tf.divide(pred_0, gt_back)

    # Compute mIoU of the lane classes: IoU_k = TP_k / (|gt == k| + |pred == k| - TP_k)
    lane_ious = []
    for k, overlap in zip((1, 2, 3, 4), (pred_1, pred_2, pred_3, pred_4)):
        union = tf.add(
            tf.count_nonzero(
                tf.cast(tf.equal(instance_label_tensor, k), tf.int64)),
            tf.count_nonzero(tf.cast(tf.equal(out_logits_out, k), tf.int64)))
        union = tf.subtract(union, overlap)
        lane_ious.append(tf.divide(overlap, union))

    IoU = tf.reduce_mean(tf.stack(lane_ious))

    global_step = tf.Variable(0, trainable=False)

    learning_rate = tf.train.polynomial_decay(CFG.TRAIN.LEARNING_RATE,
                                              global_step,
                                              90100,
                                              power=0.9)
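    # The schedule above follows tf.train.polynomial_decay (step clamped to 90100):
    # lr = (LEARNING_RATE - end_lr) * (1 - step / 90100) ** 0.9 + end_lr,
    # with end_lr = tf.train.polynomial_decay's default end_learning_rate of 1e-4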

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        optimizer = tf.train.MomentumOptimizer(
            learning_rate=learning_rate,
            momentum=0.9).minimize(loss=total_loss,
                                   var_list=tf.trainable_variables(),
                                   global_step=global_step)

    # Set tf saver
    saver = tf.train.Saver()
    model_save_dir = 'model/culane_lanenet/culane_scnn'
    if not ops.exists(model_save_dir):
        os.makedirs(model_save_dir)
    train_start_time = time.strftime('%Y-%m-%d-%H-%M-%S',
                                     time.localtime(time.time()))
    model_name = 'culane_lanenet_{:s}_{:s}.ckpt'.format(
        net_flag, str(train_start_time))
    model_save_path = ops.join(model_save_dir, model_name)

    # Set sess configuration
    sess_config = tf.ConfigProto(device_count={'GPU': 4})
    sess_config.gpu_options.per_process_gpu_memory_fraction = CFG.TRAIN.GPU_MEMORY_FRACTION
    sess_config.gpu_options.allow_growth = CFG.TRAIN.TF_ALLOW_GROWTH
    sess_config.gpu_options.allocator_type = 'BFC'

    sess = tf.Session(config=sess_config)

    # Set the training parameters
    train_epochs = CFG.TRAIN.EPOCHS

    log.info('Global configuration is as follows:')
    log.info(CFG)

    with sess.as_default():

        if weights_path is None:
            log.info('Training from scratch')
            init = tf.global_variables_initializer()
            sess.run(init)
        else:
            log.info('Restore model from last model checkpoint {:s}'.format(
                weights_path))
            saver.restore(sess=sess, save_path=weights_path)

        # Load the pre-trained VGG weights
        if net_flag == 'vgg' and weights_path is None:
            pretrained_weights = np.load('./data/vgg16.npy',
                                         encoding='latin1').item()

            for vv in tf.trainable_variables():
                weights_key = vv.name.split('/')[-3]
                try:
                    weights = pretrained_weights[weights_key][0]
                    _op = tf.assign(vv, weights)
                    sess.run(_op)
                except Exception:
                    # skip variables that have no matching pretrained weights
                    continue

        train_cost_time_mean = []
        train_instance_loss_mean = []
        train_existence_loss_mean = []
        train_accuracy_mean = []
        train_accuracy_back_mean = []

        val_cost_time_mean = []
        val_instance_loss_mean = []
        val_existence_loss_mean = []
        val_accuracy_mean = []
        val_accuracy_back_mean = []
        val_IoU_mean = []

        for epoch in range(train_epochs):
            # training part
            t_start = time.time()

            gt_imgs, instance_gt_labels, existence_gt_labels = train_dataset.next_batch(
                CFG.TRAIN.BATCH_SIZE)

            gt_imgs = [
                cv2.resize(tmp,
                           dsize=(CFG.TRAIN.IMG_WIDTH, CFG.TRAIN.IMG_HEIGHT),
                           dst=tmp,
                           interpolation=cv2.INTER_CUBIC) for tmp in gt_imgs
            ]
            gt_imgs = [(tmp - VGG_MEAN) for tmp in gt_imgs]

            instance_gt_labels = [
                cv2.resize(tmp,
                           dsize=(CFG.TRAIN.IMG_WIDTH, CFG.TRAIN.IMG_HEIGHT),
                           dst=tmp,
                           interpolation=cv2.INTER_NEAREST)
                for tmp in instance_gt_labels
            ]

            phase_train = 'train'

            _, c, train_accuracy, train_accuracy_back, train_instance_loss, train_existence_loss, binary_seg_img = \
                sess.run([optimizer, total_loss,
                          accuracy, accuracy_back,
                          instance_loss,
                          existence_loss,
                          out_logits_out],
                         feed_dict={input_tensor: gt_imgs,
                                    instance_label_tensor: instance_gt_labels,
                                    existence_label_tensor: existence_gt_labels,
                                    phase: phase_train})

            cost_time = time.time() - t_start
            train_cost_time_mean.append(cost_time)
            train_instance_loss_mean.append(train_instance_loss)
            train_existence_loss_mean.append(train_existence_loss)
            train_accuracy_mean.append(train_accuracy)
            train_accuracy_back_mean.append(train_accuracy_back)

            if epoch % CFG.TRAIN.DISPLAY_STEP == 0:
                print(
                    'Epoch: {:d} loss_ins= {:6f} ({:6f}) loss_ext= {:6f} ({:6f}) accuracy= {:6f} ({:6f}) accuracy_back= {:6f} ({:6f})'
                    ' mean_time= {:5f}s '.format(
                        epoch + 1, train_instance_loss,
                        np.mean(train_instance_loss_mean),
                        train_existence_loss,
                        np.mean(train_existence_loss_mean), train_accuracy,
                        np.mean(train_accuracy_mean), train_accuracy_back,
                        np.mean(train_accuracy_back_mean),
                        np.mean(train_cost_time_mean)))

            if epoch % 500 == 0:
                train_cost_time_mean.clear()
                train_instance_loss_mean.clear()
                train_existence_loss_mean.clear()
                train_accuracy_mean.clear()
                train_accuracy_back_mean.clear()

            if epoch % 1000 == 0:
                saver.save(sess=sess,
                           save_path=model_save_path,
                           global_step=epoch)

            if epoch % 10000 != 0 or epoch == 0:
                continue

            # hard-coded: presumably 9675 validation images at a batch size of 8
            for epoch_val in range(int(9675 / 8.0)):

                # validation part
                gt_imgs_val, instance_gt_labels_val, existence_gt_labels_val \
                  = val_dataset.next_batch(CFG.TRAIN.VAL_BATCH_SIZE)
                gt_imgs_val = [
                    cv2.resize(tmp,
                               dsize=(CFG.TRAIN.IMG_WIDTH,
                                      CFG.TRAIN.IMG_HEIGHT),
                               dst=tmp,
                               interpolation=cv2.INTER_CUBIC)
                    for tmp in gt_imgs_val
                ]
                gt_imgs_val = [(tmp - VGG_MEAN) for tmp in gt_imgs_val]

                instance_gt_labels_val = [
                    cv2.resize(tmp,
                               dsize=(CFG.TRAIN.IMG_WIDTH,
                                      CFG.TRAIN.IMG_HEIGHT),
                               dst=tmp,
                               interpolation=cv2.INTER_NEAREST)
                    for tmp in instance_gt_labels_val
                ]
                phase_val = 'test'

                t_start_val = time.time()
                c_val, val_accuracy, val_accuracy_back, val_IoU, val_instance_loss, val_existence_loss = \
                  sess.run([total_loss, accuracy, accuracy_back, IoU, instance_loss, existence_loss],
                             feed_dict={input_tensor: gt_imgs_val,
                                        instance_label_tensor: instance_gt_labels_val,
                                        existence_label_tensor: existence_gt_labels_val,
                                        phase: phase_val})

                cost_time_val = time.time() - t_start_val
                val_cost_time_mean.append(cost_time_val)
                val_instance_loss_mean.append(val_instance_loss)
                val_existence_loss_mean.append(val_existence_loss)
                val_accuracy_mean.append(val_accuracy)
                val_accuracy_back_mean.append(val_accuracy_back)
                val_IoU_mean.append(val_IoU)

                if epoch_val % 1 == 0:  # always true; prints every validation batch
                    print(
                        'Epoch_Val: {:d} loss_ins= {:6f} ({:6f}) '
                        'loss_ext= {:6f} ({:6f}) accuracy= {:6f} ({:6f}) accuracy_back= {:6f} ({:6f}) mIoU= {:6f} ({:6f}) '
                        'mean_time= {:5f}s '.format(
                            epoch_val + 1, val_instance_loss,
                            np.mean(val_instance_loss_mean),
                            val_existence_loss,
                            np.mean(val_existence_loss_mean), val_accuracy,
                            np.mean(val_accuracy_mean), val_accuracy_back,
                            np.mean(val_accuracy_back_mean), val_IoU,
                            np.mean(val_IoU_mean),
                            np.mean(val_cost_time_mean)))

            val_cost_time_mean.clear()
            val_instance_loss_mean.clear()
            val_existence_loss_mean.clear()
            val_accuracy_mean.clear()
            val_accuracy_back_mean.clear()
            val_IoU_mean.clear()

    sess.close()

    return
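
The IoU computed in this example is the standard per-class intersection over union, TP_k / (|gt == k| + |pred == k| - TP_k), averaged over the four lane classes. A NumPy sanity check of the same metric, for illustration only:

import numpy as np

def lane_miou(label, pred, classes=(1, 2, 3, 4)):
    """Mean IoU over the lane classes; assumes each class occurs in label or pred."""
    ious = []
    for k in classes:
        tp = np.count_nonzero((label == k) & (pred == k))
        union = np.count_nonzero(label == k) + np.count_nonzero(pred == k) - tp
        ious.append(tp / union)
    return np.mean(ious)

label = np.array([[0, 1, 1, 2], [0, 2, 3, 4]])
pred = np.array([[0, 1, 2, 2], [0, 2, 3, 4]])
print(lane_miou(label, pred))  # (1/2 + 2/3 + 1 + 1) / 4 ~= 0.7917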
Example #4
def train_net(dataset_dir, weights_path=None, net_flag='vgg'):
    """

    :param dataset_dir:
    :param net_flag: choose which base network to use
    :param weights_path:
    :return:
    """
    train_dataset_file = ops.join(dataset_dir, 'train_gt.txt')
    val_dataset_file = ops.join(dataset_dir, 'val_gt.txt')

    assert ops.exists(train_dataset_file)

    train_dataset = lanenet_data_processor.DataSet(train_dataset_file)
    val_dataset = lanenet_data_processor.DataSet(val_dataset_file)

    input_tensor = tf.placeholder(dtype=tf.float32,
                                  shape=[
                                      CFG.TRAIN.BATCH_SIZE,
                                      CFG.TRAIN.IMG_HEIGHT,
                                      CFG.TRAIN.IMG_WIDTH, 3
                                  ],
                                  name='input_tensor')
    instance_label_tensor = tf.placeholder(dtype=tf.int64,
                                           shape=[
                                               CFG.TRAIN.BATCH_SIZE,
                                               CFG.TRAIN.IMG_HEIGHT,
                                               CFG.TRAIN.IMG_WIDTH
                                           ],
                                           name='instance_input_label')
    existence_label_tensor = tf.placeholder(dtype=tf.float32,
                                            shape=[CFG.TRAIN.BATCH_SIZE, 4],
                                            name='existence_input_label')
    phase = tf.placeholder(dtype=tf.string, shape=None, name='net_phase')

    net = lanenet_merge_model.LaneNet(net_flag=net_flag, phase=phase)

    # calculate the loss
    compute_ret = net.compute_loss(input_tensor=input_tensor,
                                   binary_label=instance_label_tensor,
                                   existence_label=existence_label_tensor,
                                   name='lanenet_loss')
    total_loss = compute_ret['total_loss']
    instance_loss = compute_ret['instance_seg_loss']
    existence_loss = compute_ret['existence_pre_loss']
    existence_logits = compute_ret['existence_logits']

    # calculate the accuracy: fraction of ground-truth lane pixels (label == 1)
    # that the network classifies as non-background
    out_logits = compute_ret['instance_seg_logits']
    out_logits = tf.nn.softmax(logits=out_logits)
    out_logits_out = tf.argmax(out_logits, axis=-1)
    out = tf.expand_dims(out_logits_out, axis=-1)

    idx = tf.where(tf.equal(instance_label_tensor, 1))
    pix_cls_ret = tf.gather_nd(out, idx)
    accuracy = tf.count_nonzero(pix_cls_ret)
    accuracy = tf.divide(accuracy, tf.cast(tf.shape(pix_cls_ret)[0], tf.int64))

    global_step = tf.Variable(0, trainable=False)
    learning_rate = tf.train.exponential_decay(CFG.TRAIN.LEARNING_RATE,
                                               global_step,
                                               5000,
                                               0.96,
                                               staircase=True)
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        optimizer = tf.train.AdamOptimizer(
            learning_rate=learning_rate).minimize(
                loss=total_loss,
                var_list=tf.trainable_variables(),
                global_step=global_step)

    # Set tf saver
    saver = tf.train.Saver()
    model_save_dir = 'model/culane_lanenet/culane_scnn'
    if not ops.exists(model_save_dir):
        os.makedirs(model_save_dir)
    train_start_time = time.strftime('%Y-%m-%d-%H-%M-%S',
                                     time.localtime(time.time()))
    model_name = 'culane_lanenet_{:s}_{:s}.ckpt'.format(
        net_flag, str(train_start_time))
    model_save_path = ops.join(model_save_dir, model_name)

    # Set sess configuration
    sess_config = tf.ConfigProto(device_count={'GPU': 4})
    sess_config.gpu_options.per_process_gpu_memory_fraction = CFG.TRAIN.GPU_MEMORY_FRACTION
    sess_config.gpu_options.allow_growth = CFG.TRAIN.TF_ALLOW_GROWTH
    sess_config.gpu_options.allocator_type = 'BFC'

    sess = tf.Session(config=sess_config)

    # Set the training parameters
    train_epochs = CFG.TRAIN.EPOCHS

    log.info('Global configuration is as follows:')
    log.info(CFG)

    with sess.as_default():

        tf.train.write_graph(
            graph_or_graph_def=sess.graph,
            logdir='',
            name='{:s}/lanenet_model.pb'.format(model_save_dir))

        if weights_path is None:
            log.info('Training from scratch')
            init = tf.global_variables_initializer()
            sess.run(init)
        else:
            log.info('Restore model from last model checkpoint {:s}'.format(
                weights_path))
            saver.restore(sess=sess, save_path=weights_path)

        # Load the pre-trained VGG weights
        if net_flag == 'vgg' and weights_path is None:
            pretrained_weights = np.load('./data/vgg16.npy',
                                         encoding='latin1').item()

            for vv in tf.trainable_variables():
                weights_key = vv.name.split('/')[-3]
                try:
                    weights = pretrained_weights[weights_key][0]
                    _op = tf.assign(vv, weights)
                    sess.run(_op)
                except Exception:
                    # skip variables that have no matching pretrained weights
                    continue

        train_cost_time_mean = []
        train_instance_loss_mean = []
        train_existence_loss_mean = []
        train_accuracy_mean = []

        val_cost_time_mean = []
        val_instance_loss_mean = []
        val_existence_loss_mean = []
        val_accuracy_mean = []

        for epoch in range(train_epochs):
            # training part
            t_start = time.time()

            gt_imgs, instance_gt_labels, existence_gt_labels = train_dataset.next_batch(
                CFG.TRAIN.BATCH_SIZE)
            gt_imgs = [
                cv2.resize(tmp,
                           dsize=(CFG.TRAIN.IMG_WIDTH, CFG.TRAIN.IMG_HEIGHT),
                           dst=tmp,
                           interpolation=cv2.INTER_LINEAR) for tmp in gt_imgs
            ]
            gt_imgs = [tmp - VGG_MEAN for tmp in gt_imgs]

            instance_gt_labels = [
                cv2.resize(tmp,
                           dsize=(CFG.TRAIN.IMG_WIDTH, CFG.TRAIN.IMG_HEIGHT),
                           dst=tmp,
                           interpolation=cv2.INTER_NEAREST)
                for tmp in instance_gt_labels
            ]

            phase_train = 'train'

            _, c, train_accuracy, train_instance_loss, train_existence_loss, binary_seg_img = \
                sess.run([optimizer, total_loss,
                          accuracy,
                          instance_loss,
                          existence_loss,
                          out_logits_out],
                         feed_dict={input_tensor: gt_imgs,
                                    instance_label_tensor: instance_gt_labels,
                                    existence_label_tensor: existence_gt_labels,
                                    phase: phase_train})

            cost_time = time.time() - t_start
            train_cost_time_mean.append(cost_time)
            train_instance_loss_mean.append(train_instance_loss)
            train_existence_loss_mean.append(train_existence_loss)
            train_accuracy_mean.append(train_accuracy)

            # validation part
            gt_imgs_val, instance_gt_labels_val, existence_gt_labels_val \
                = val_dataset.next_batch(CFG.TRAIN.VAL_BATCH_SIZE)
            gt_imgs_val = [
                cv2.resize(tmp,
                           dsize=(CFG.TRAIN.IMG_WIDTH, CFG.TRAIN.IMG_HEIGHT),
                           dst=tmp,
                           interpolation=cv2.INTER_LINEAR)
                for tmp in gt_imgs_val
            ]
            gt_imgs_val = [tmp - VGG_MEAN for tmp in gt_imgs_val]
            instance_gt_labels_val = [
                cv2.resize(tmp,
                           dsize=(CFG.TRAIN.IMG_WIDTH, CFG.TRAIN.IMG_HEIGHT),
                           dst=tmp,
                           interpolation=cv2.INTER_NEAREST)
                for tmp in instance_gt_labels_val
            ]
            phase_val = 'test'

            t_start_val = time.time()
            c_val, val_accuracy, val_instance_loss, val_existence_loss = \
                sess.run([total_loss, accuracy, instance_loss, existence_loss],
                         feed_dict={input_tensor: gt_imgs_val,
                                    instance_label_tensor: instance_gt_labels_val,
                                    existence_label_tensor: existence_gt_labels_val,
                                    phase: phase_val})

            cost_time_val = time.time() - t_start_val
            val_cost_time_mean.append(cost_time_val)
            val_instance_loss_mean.append(val_instance_loss)
            val_existence_loss_mean.append(val_existence_loss)
            val_accuracy_mean.append(val_accuracy)

            if epoch % CFG.TRAIN.DISPLAY_STEP == 0:
                print(
                    'Epoch: {:d} loss_ins= {:6f} ({:6f}) loss_ext= {:6f} ({:6f}) accuracy= {:6f} ({:6f})'
                    ' mean_time= {:5f}s '.format(
                        epoch + 1, train_instance_loss,
                        np.mean(train_instance_loss_mean),
                        train_existence_loss,
                        np.mean(train_existence_loss_mean), train_accuracy,
                        np.mean(train_accuracy_mean),
                        np.mean(train_cost_time_mean)))  # log.info

            if epoch % CFG.TRAIN.TEST_DISPLAY_STEP == 0:
                print('Epoch_Val: {:d} loss_ins= {:6f} ({:6f}) '
                      'loss_ext= {:6f} ({:6f}) accuracy= {:6f} ({:6f}) '
                      'mean_time= {:5f}s '.format(
                          epoch + 1, val_instance_loss,
                          np.mean(val_instance_loss_mean), val_existence_loss,
                          np.mean(val_existence_loss_mean), val_accuracy,
                          np.mean(val_accuracy_mean),
                          np.mean(val_cost_time_mean)))

            if epoch % 500 == 0:
                train_cost_time_mean.clear()
                train_instance_loss_mean.clear()
                train_existence_loss_mean.clear()
                train_accuracy_mean.clear()

                val_cost_time_mean.clear()
                val_instance_loss_mean.clear()
                val_existence_loss_mean.clear()
                val_accuracy_mean.clear()

            if epoch % 2000 == 0:
                saver.save(sess=sess,
                           save_path=model_save_path,
                           global_step=epoch)
    sess.close()

    return
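
Unlike the batch pixel accuracy of Example #2, the accuracy here is a recall over ground-truth lane pixels: the share of pixels labelled 1 that receive a non-background prediction. The tf.where/tf.gather_nd construction corresponds to boolean indexing in NumPy; a short sketch, for illustration only:

import numpy as np

def lane_pixel_accuracy(label, pred):
    """Fraction of ground-truth lane pixels (label == 1) predicted non-background."""
    lane_pred = pred[label == 1]  # NumPy analogue of tf.where + tf.gather_nd
    return np.count_nonzero(lane_pred) / lane_pred.shape[0]

label = np.array([[0, 1, 1, 1]])
pred = np.array([[0, 1, 0, 1]])
print(lane_pixel_accuracy(label, pred))  # 2 of 3 lane pixels hit -> 0.6667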
Example #5
def train_net(dataset_dir, weights_path=None, net_flag='vgg'):
    """
    Multi-GPU (tower-parallel) training of the SCNN-style LaneNet on CULane.

    :param dataset_dir: directory containing train_gt.txt and val_gt.txt
    :param weights_path: optional checkpoint to resume training from
    :param net_flag: choose which base network to use
    :return:
    """
    train_dataset_file = ops.join(dataset_dir, 'train_gt.txt')
    val_dataset_file = ops.join(dataset_dir, 'val_gt.txt')

    assert ops.exists(train_dataset_file)

    phase = tf.placeholder(dtype=tf.string, shape=None, name='net_phase')

    train_dataset = lanenet_data_processor.DataSet(train_dataset_file)
    val_dataset = lanenet_data_processor.DataSet(val_dataset_file)

    net = lanenet_merge_model.LaneNet()

    tower_grads = []

    global_step = tf.Variable(0, trainable=False)

    learning_rate = tf.train.polynomial_decay(CFG.TRAIN.LEARNING_RATE,
                                              global_step,
                                              CFG.TRAIN.EPOCHS,
                                              power=0.9)

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate,
                                               momentum=0.9)
    img, label_instance, label_existence = train_dataset.next_batch(
        CFG.TRAIN.BATCH_SIZE)
    batch_queue = tf.contrib.slim.prefetch_queue.prefetch_queue(
        [img, label_instance, label_existence],
        capacity=2 * CFG.TRAIN.GPU_NUM,
        num_threads=CFG.TRAIN.CPU_NUM)

    val_img, val_label_instance, val_label_existence = val_dataset.next_batch(
        CFG.TRAIN.BATCH_SIZE)
    val_batch_queue = tf.contrib.slim.prefetch_queue.prefetch_queue(
        [val_img, val_label_instance, val_label_existence],
        capacity=2 * CFG.TRAIN.GPU_NUM,
        num_threads=CFG.TRAIN.CPU_NUM)
    with tf.variable_scope(tf.get_variable_scope()):
        for i in range(CFG.TRAIN.GPU_NUM):
            with tf.device('/gpu:%d' % i):
                with tf.name_scope('tower_%d' % i):
                    total_loss, instance_loss, existence_loss, accuracy, accuracy_back, _, out_logits_out, \
                        grad = forward(batch_queue, net, phase, optimizer)
                    tower_grads.append(grad)
                    val_op_total_loss, val_op_instance_loss, val_op_existence_loss, val_op_accuracy, \
                        val_op_accuracy_back, val_op_IoU, _, _ = forward(val_batch_queue, net, phase)

    grads = average_gradients(tower_grads)

    train_op = optimizer.apply_gradients(grads, global_step=global_step)

    train_cost_time_mean = []
    train_instance_loss_mean = []
    train_existence_loss_mean = []
    train_accuracy_mean = []
    train_accuracy_back_mean = []

    saver = tf.train.Saver()
    model_save_dir = 'model/culane_lanenet/culane_scnn'
    if not ops.exists(model_save_dir):
        os.makedirs(model_save_dir)
    train_start_time = time.strftime('%Y-%m-%d-%H-%M-%S',
                                     time.localtime(time.time()))
    model_name = 'culane_lanenet_{:s}_{:s}.ckpt'.format(
        net_flag, str(train_start_time))
    model_save_path = ops.join(model_save_dir, model_name)

    sess_config = tf.ConfigProto(device_count={'GPU': CFG.TRAIN.GPU_NUM},
                                 allow_soft_placement=True)
    sess_config.gpu_options.per_process_gpu_memory_fraction = CFG.TRAIN.GPU_MEMORY_FRACTION
    sess_config.gpu_options.allow_growth = CFG.TRAIN.TF_ALLOW_GROWTH
    sess_config.gpu_options.allocator_type = 'BFC'

    with tf.Session(config=sess_config) as sess:
        with sess.as_default():

            if weights_path is None:
                log.info('Training from scratch')
                init = tf.global_variables_initializer()
                sess.run(init)
            else:
                log.info(
                    'Restore model from last model checkpoint {:s}'.format(
                        weights_path))
                saver.restore(sess=sess, save_path=weights_path)

            # Load the pre-trained VGG weights
            if net_flag == 'vgg' and weights_path is None:
                pretrained_weights = np.load('./data/vgg16.npy',
                                             encoding='latin1').item()

                for vv in tf.trainable_variables():
                    parts = vv.name.split('/')
                    if len(parts) >= 3 and parts[-3] in pretrained_weights:
                        try:
                            weights = pretrained_weights[parts[-3]][0]
                            _op = tf.assign(vv, weights)
                            sess.run(_op)
                        except Exception:
                            # skip variables whose shapes do not match
                            continue
        tf.train.start_queue_runners(sess=sess)
        for epoch in range(CFG.TRAIN.EPOCHS):
            t_start = time.time()

            _, c, train_accuracy, train_accuracy_back, train_instance_loss, train_existence_loss, binary_seg_img = \
                sess.run([train_op, total_loss, accuracy, accuracy_back, instance_loss, existence_loss, out_logits_out],
                         feed_dict={phase: 'train'})

            cost_time = time.time() - t_start
            train_cost_time_mean.append(cost_time)
            train_instance_loss_mean.append(train_instance_loss)
            train_existence_loss_mean.append(train_existence_loss)
            train_accuracy_mean.append(train_accuracy)
            train_accuracy_back_mean.append(train_accuracy_back)

            if epoch % CFG.TRAIN.DISPLAY_STEP == 0:
                print(
                    'Epoch: {:d} loss_ins= {:6f} ({:6f}) loss_ext= {:6f} ({:6f}) accuracy= {:6f} ({:6f}) '
                    'accuracy_back= {:6f} ({:6f}) mean_time= {:5f}s '.format(
                        epoch + 1, train_instance_loss,
                        np.mean(train_instance_loss_mean),
                        train_existence_loss,
                        np.mean(train_existence_loss_mean), train_accuracy,
                        np.mean(train_accuracy_mean), train_accuracy_back,
                        np.mean(train_accuracy_back_mean),
                        np.mean(train_cost_time_mean)))

            if epoch % 500 == 0:
                train_cost_time_mean.clear()
                train_instance_loss_mean.clear()
                train_existence_loss_mean.clear()
                train_accuracy_mean.clear()
                train_accuracy_back_mean.clear()

            if epoch % 1000 == 0:
                saver.save(sess=sess,
                           save_path=model_save_path,
                           global_step=epoch)

            if epoch % 10000 != 0 or epoch == 0:
                continue

            val_cost_time_mean = []
            val_instance_loss_mean = []
            val_existence_loss_mean = []
            val_accuracy_mean = []
            val_accuracy_back_mean = []
            val_IoU_mean = []

            for epoch_val in range(
                    int(
                        len(val_dataset) / CFG.TRAIN.VAL_BATCH_SIZE /
                        CFG.TRAIN.GPU_NUM)):
                t_start_val = time.time()
                c_val, val_accuracy, val_accuracy_back, val_IoU, val_instance_loss, val_existence_loss = \
                    sess.run(
                        [val_op_total_loss, val_op_accuracy, val_op_accuracy_back,
                         val_op_IoU, val_op_instance_loss, val_op_existence_loss],
                        feed_dict={phase: 'test'})

                cost_time_val = time.time() - t_start_val
                val_cost_time_mean.append(cost_time_val)
                val_instance_loss_mean.append(val_instance_loss)
                val_existence_loss_mean.append(val_existence_loss)
                val_accuracy_mean.append(val_accuracy)
                val_accuracy_back_mean.append(val_accuracy_back)
                val_IoU_mean.append(val_IoU)

                if epoch_val % 1 == 0:  # always true; prints every validation batch
                    print(
                        'Epoch_Val: {:d} loss_ins= {:6f} ({:6f}) '
                        'loss_ext= {:6f} ({:6f}) accuracy= {:6f} ({:6f}) accuracy_back= {:6f} ({:6f}) '
                        'mIoU= {:6f} ({:6f}) mean_time= {:5f}s '.format(
                            epoch_val + 1, val_instance_loss,
                            np.mean(val_instance_loss_mean),
                            val_existence_loss,
                            np.mean(val_existence_loss_mean), val_accuracy,
                            np.mean(val_accuracy_mean), val_accuracy_back,
                            np.mean(val_accuracy_back_mean), val_IoU,
                            np.mean(val_IoU_mean),
                            np.mean(val_cost_time_mean)))

            val_cost_time_mean.clear()
            val_instance_loss_mean.clear()
            val_existence_loss_mean.clear()
            val_accuracy_mean.clear()
            val_accuracy_back_mean.clear()
            val_IoU_mean.clear()
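
average_gradients is not shown in this listing. A sketch of the usual implementation, following the standard TensorFlow multi-GPU example (cifar10_multi_gpu_train.py), which is presumably what this code relies on:

def average_gradients(tower_grads):
    """Average gradients across GPU towers.

    tower_grads: one list of (gradient, variable) pairs per tower, as returned by
    optimizer.compute_gradients() on each tower; assumes every tower produces a
    gradient for every variable.
    """
    average_grads = []
    for grad_and_vars in zip(*tower_grads):
        # grad_and_vars is ((grad_gpu0, var), ..., (grad_gpuN, var)) for one variable
        grads = [tf.expand_dims(g, 0) for g, _ in grad_and_vars]
        grad = tf.reduce_mean(tf.concat(grads, axis=0), axis=0)
        # The variable is shared across towers, so take it from the first tower
        average_grads.append((grad, grad_and_vars[0][1]))
    return average_grads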
Example #6
def train_net(
        dataset_dir,
        weights_path=None,
        net_flag='vgg',
        save_dir="./logs/train/lanenet",
        tboard_save_path="./tboard/lanenet",
        ignore_labels_path="/media/remus/datasets/AVMSnapshots/AVM/ignore_labels.png",
        my_checkpoint="true"):
    """

    :param save_dir:
    :param ignore_labels_path:
    :param tboard_save_path:
    :param dataset_dir:
    :param net_flag: choose which base network to use
    :param weights_path:
    :return:
    """
    train_dataset_file = ops.join(dataset_dir, 'train.txt')
    val_dataset_file = ops.join(dataset_dir, 'val.txt')

    assert ops.exists(train_dataset_file)
    # tf.enable_eager_execution()

    train_dataset = lanenet_data_processor.DataSet(train_dataset_file)
    val_dataset = lanenet_data_processor.DataSet(val_dataset_file)

    with tf.device('/gpu:1'):
        input_tensor = tf.placeholder(dtype=tf.float32,
                                      shape=[
                                          CFG.TRAIN.BATCH_SIZE,
                                          CFG.TRAIN.IMG_HEIGHT,
                                          CFG.TRAIN.IMG_WIDTH, 3
                                      ],
                                      name='input_tensor')
        binary_label_tensor = tf.placeholder(dtype=tf.int64,
                                             shape=[
                                                 CFG.TRAIN.BATCH_SIZE,
                                                 CFG.TRAIN.IMG_HEIGHT,
                                                 CFG.TRAIN.IMG_WIDTH, 1
                                             ],
                                             name='binary_input_label')
        instance_label_tensor = tf.placeholder(dtype=tf.float32,
                                               shape=[
                                                   CFG.TRAIN.BATCH_SIZE,
                                                   CFG.TRAIN.IMG_HEIGHT,
                                                   CFG.TRAIN.IMG_WIDTH
                                               ],
                                               name='instance_input_label')
        phase = tf.placeholder(dtype=tf.string, shape=None, name='net_phase')

        net = lanenet_merge_model.LaneNet(net_flag=net_flag, phase=phase)

        # calculate the loss
        compute_ret = net.compute_loss(input_tensor=input_tensor,
                                       binary_label=binary_label_tensor,
                                       instance_label=instance_label_tensor,
                                       ignore_label=255,
                                       name='lanenet_model')
        total_loss = compute_ret['total_loss']
        binary_seg_loss = compute_ret['binary_seg_loss']
        disc_loss = compute_ret['discriminative_loss']
        pix_embedding = compute_ret['instance_seg_logits']

        # calculate the accuracy: fraction of ground-truth lane pixels (label == 1)
        # that the network classifies as non-background
        out_logits = compute_ret['binary_seg_logits']
        out_logits = tf.nn.softmax(logits=out_logits)
        out_logits_out = tf.argmax(out_logits, axis=-1)
        out = tf.expand_dims(out_logits_out, axis=-1)

        idx = tf.where(tf.equal(binary_label_tensor, 1))
        pix_cls_ret = tf.gather_nd(out, idx)
        accuracy = tf.count_nonzero(pix_cls_ret)
        accuracy = tf.divide(accuracy,
                             tf.cast(tf.shape(pix_cls_ret)[0], tf.int64))

        global_step = tf.Variable(0, trainable=False)
        learning_rate = tf.train.exponential_decay(CFG.TRAIN.LEARNING_RATE,
                                                   global_step,
                                                   CFG.TRAIN.LR_DECAY_STEPS,
                                                   CFG.TRAIN.LR_DECAY_RATE,
                                                   staircase=True)
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            optimizer = tf.train.MomentumOptimizer(
                learning_rate=learning_rate,
                momentum=0.9).minimize(loss=total_loss,
                                       var_list=tf.trainable_variables(),
                                       global_step=global_step)
            # update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

    # Set tf saver

    if my_checkpoint == "true":
        init_saver = tf.train.Saver()

    else:
        from correct_path_saver import restore_from_classification_checkpoint_fn, get_variables_available_in_checkpoint
        if weights_path is not None:
            # var_map = restore_from_classification_checkpoint_fn("lanenet_model/inference")
            available_var_map = (get_variables_available_in_checkpoint(
                tf.global_variables(), weights_path,
                include_global_step=False))

            init_saver = tf.train.Saver(available_var_map)
        else:
            init_saver = tf.train.Saver()

    if not ops.exists(save_dir):
        os.makedirs(save_dir)

    train_start_time = time.strftime('%Y-%m-%d-%H-%M-%S',
                                     time.localtime(time.time()))
    model_name = '{:s}_lanenet_{:s}.ckpt'.format(net_flag,
                                                 str(train_start_time))
    model_save_path = ops.join(save_dir, model_name)

    # Set tf summary
    if not ops.exists(tboard_save_path):
        os.makedirs(tboard_save_path)
    train_cost_scalar = tf.summary.scalar(name='train_cost', tensor=total_loss)
    val_cost_scalar = tf.summary.scalar(name='val_cost', tensor=total_loss)
    train_accuracy_scalar = tf.summary.scalar(name='train_accuracy',
                                              tensor=accuracy)
    val_accuracy_scalar = tf.summary.scalar(name='val_accuracy',
                                            tensor=accuracy)
    train_binary_seg_loss_scalar = tf.summary.scalar(
        name='train_binary_seg_loss', tensor=binary_seg_loss)
    val_binary_seg_loss_scalar = tf.summary.scalar(name='val_binary_seg_loss',
                                                   tensor=binary_seg_loss)
    train_instance_seg_loss_scalar = tf.summary.scalar(
        name='train_instance_seg_loss', tensor=disc_loss)
    val_instance_seg_loss_scalar = tf.summary.scalar(
        name='val_instance_seg_loss', tensor=disc_loss)
    learning_rate_scalar = tf.summary.scalar(name='learning_rate',
                                             tensor=learning_rate)
    train_merge_summary_op = tf.summary.merge([
        train_accuracy_scalar, train_cost_scalar, learning_rate_scalar,
        train_binary_seg_loss_scalar, train_instance_seg_loss_scalar
    ])
    val_merge_summary_op = tf.summary.merge([
        val_accuracy_scalar, val_cost_scalar, val_binary_seg_loss_scalar,
        val_instance_seg_loss_scalar
    ])

    # Set sess configuration
    sess_config = tf.ConfigProto(allow_soft_placement=True)
    sess_config.gpu_options.per_process_gpu_memory_fraction = CFG.TRAIN.GPU_MEMORY_FRACTION
    sess_config.gpu_options.allow_growth = CFG.TRAIN.TF_ALLOW_GROWTH
    sess_config.gpu_options.allocator_type = 'BFC'
    # sess_config.device_count = {'GPU': 0}

    sess = tf.Session(config=sess_config)
    # sess = tf_debug.TensorBoardDebugWrapperSession(sess=sess,
    #                                                grpc_debug_server_addresses="remusm-pc:7000",
    #                                                send_traceback_and_source_code=False)

    summary_writer = tf.summary.FileWriter(tboard_save_path)
    summary_writer.add_graph(sess.graph)

    # Set the training parameters
    train_epochs = CFG.TRAIN.EPOCHS

    tf.logging.info('Global configuration is as follows:')
    tf.logging.info(CFG)

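    # iter_saver keeps the 10 most recent periodic checkpoints; best_saver keeps
    # the 3 checkpoints with the lowest training loss seen so far.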
    iter_saver = tf.train.Saver(max_to_keep=10)
    best_saver = tf.train.Saver(max_to_keep=3)

    with sess.as_default():

        sess.run(tf.global_variables_initializer())

        tf.train.write_graph(graph_or_graph_def=sess.graph,
                             logdir='',
                             name='{:s}/lanenet_model.pb'.format(save_dir))

        if weights_path is None:
            tf.logging.info('Training from scratch')
            init = tf.global_variables_initializer()
            sess.run(init)
        else:
            tf.logging.info(
                'Restore model from last model checkpoint {:s}'.format(
                    weights_path))
            init_saver.restore(sess=sess, save_path=weights_path)

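            # Reset global_step so the learning-rate decay schedule starts over
            # when fine-tuning from an existing checkpoint.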
            assign_op = global_step.assign(0)
            sess.run(assign_op)

        # Load the pretrained parameters
        if net_flag == 'vgg' and weights_path is None:
            pretrained_weights = np.load('./data/vgg16.npy',
                                         encoding='latin1').item()

            for vv in tf.trainable_variables():
                weights_key = vv.name.split('/')[-3]
                try:
                    weights = pretrained_weights[weights_key][0]
                    _op = tf.assign(vv, weights)
                    sess.run(_op)
                except Exception:
                    # No matching entry in the VGG-16 weight file; leave this
                    # variable at its initialized value.
                    continue

        train_cost_time_mean = []
        val_cost_time_mean = []
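        # Mask image marking the pixels assigned the ignore label (255), which
        # are excluded when computing the loss.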
        ignore_label_mask = cv2.imread(ignore_labels_path)
        last_c = 100000

        for epoch in range(train_epochs):
            # training part
            t_start = time.time()

            with tf.device('/cpu:0'):
                gt_imgs, binary_gt_labels, instance_gt_labels = train_dataset.next_batch(
                    CFG.TRAIN.BATCH_SIZE,
                    ignore_label_mask=ignore_label_mask,
                    ignore_label=255)

                # gt_imgs = [tmp - VGG_MEAN for tmp in gt_imgs]
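                # Scale pixel values from [0, 255] to roughly [-1, 1].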
                gt_imgs = [tmp / 128.0 - 1.0 for tmp in gt_imgs]

                binary_gt_labels = [
                    np.expand_dims(tmp, axis=-1) for tmp in binary_gt_labels
                ]

            phase_train = 'train'

            _, c, train_accuracy, train_summary, binary_loss, instance_loss, embedding, binary_seg_img, g_step = \
                sess.run([optimizer, total_loss,
                          accuracy,
                          train_merge_summary_op,
                          binary_seg_loss,
                          disc_loss,
                          pix_embedding,
                          out_logits_out,
                          global_step],
                         feed_dict={input_tensor: gt_imgs,
                                    binary_label_tensor: binary_gt_labels,
                                    instance_label_tensor: instance_gt_labels,
                                    phase: phase_train})
            # if epoch % 10 == 0:
            # tf.logging.info("Epoch {}."
            #     "Total loss: {}. Train acc: {}."
            #     " Binary loss: {}. Instance loss: {}".format(epoch, c, train_accuracy,
            #                                                  binary_loss, instance_loss))

            if math.isnan(c) or math.isnan(binary_loss) or math.isnan(
                    instance_loss):
                tf.logging.error('cost is: {:.5f}'.format(c))
                tf.logging.error('binary cost is: {:.5f}'.format(binary_loss))
                tf.logging.error(
                    'instance cost is: {:.5f}'.format(instance_loss))
                # cv2.imwrite('nan_image.png', gt_imgs[0] + VGG_MEAN)
                cv2.imwrite('nan_image.png', (gt_imgs[0] + 1.0) * 128)
                cv2.imwrite('nan_instance_label.png', instance_gt_labels[0])
                cv2.imwrite('nan_binary_label.png', binary_gt_labels[0] * 255)
                return

            if epoch % 100 == 0:
                # cv2.imwrite('nan_image.png', gt_imgs[0] + VGG_MEAN)
                cv2.imwrite('image.png', (gt_imgs[0] + 1.0) * 128)
                cv2.imwrite('binary_label.png', binary_gt_labels[0] * 255)
                cv2.imwrite('instance_label.png', instance_gt_labels[0])
                cv2.imwrite('binary_seg_img.png', binary_seg_img[0] * 255)

                for i in range(4):
                    embedding[0][:, :, i] = minmax_scale(embedding[0][:, :, i])
                embedding_image = np.array(embedding[0], np.uint8)
                cv2.imwrite('embedding.png', embedding_image[:, :, :-1])

            cost_time = time.time() - t_start
            train_cost_time_mean.append(cost_time)
            summary_writer.add_summary(summary=train_summary,
                                       global_step=epoch)

            # validation part
            with tf.device('/cpu:0'):
                gt_imgs_val, binary_gt_labels_val, instance_gt_labels_val \
                    = val_dataset.next_batch(CFG.TRAIN.VAL_BATCH_SIZE, ignore_label_mask=ignore_label_mask)

                # gt_imgs_val = [tmp - VGG_MEAN for tmp in gt_imgs_val]
                gt_imgs_val = [tmp / 128.0 - 1.0 for tmp in gt_imgs_val]

                binary_gt_labels_val = [
                    np.expand_dims(tmp, axis=-1)
                    for tmp in binary_gt_labels_val
                ]
            phase_val = 'test'

            t_start_val = time.time()
            c_val, val_summary, val_accuracy, val_binary_seg_loss, val_instance_seg_loss = \
                sess.run([total_loss, val_merge_summary_op, accuracy, binary_seg_loss, disc_loss],
                         feed_dict={input_tensor: gt_imgs_val,
                                    binary_label_tensor: binary_gt_labels_val,
                                    instance_label_tensor: instance_gt_labels_val,
                                    phase: phase_val})

            if epoch % 100 == 0:
                # cv2.imwrite('test_image.png', gt_imgs_val[0] + VGG_MEAN)
                cv2.imwrite('test_image.png', (gt_imgs_val[0] + 1.0) * 128)

            summary_writer.add_summary(val_summary, global_step=epoch)

            cost_time_val = time.time() - t_start_val
            val_cost_time_mean.append(cost_time_val)

            if epoch % CFG.TRAIN.DISPLAY_STEP == 0:
                tf.logging.info(
                    'Step: {:d} total_loss= {:6f} binary_seg_loss= {:6f} instance_seg_loss= {:6f} accuracy= {:6f}'
                    ' mean_cost_time= {:5f}s '.format(
                        epoch + 1, c, binary_loss, instance_loss,
                        train_accuracy, np.mean(train_cost_time_mean)))
                train_cost_time_mean.clear()

            if epoch % CFG.TRAIN.TEST_DISPLAY_STEP == 0:
                tf.logging.info(
                    'Step_Val: {:d} total_loss= {:6f} binary_seg_loss= {:6f} '
                    'instance_seg_loss= {:6f} accuracy= {:6f} '
                    'mean_cost_time= {:5f}s '.format(
                        epoch + 1, c_val, val_binary_seg_loss,
                        val_instance_seg_loss, val_accuracy,
                        np.mean(val_cost_time_mean)))
                val_cost_time_mean.clear()

            if epoch % 2000 == 0:
                iter_saver.save(sess=sess,
                                save_path=model_save_path,
                                global_step=epoch)

                if c < last_c:
                    last_c = c
                    save_dir_best = save_dir + "/best"
                    if not ops.exists(save_dir_best):
                        os.makedirs(save_dir_best)
                    best_model_save_path = ops.join(save_dir_best, model_name)

                    best_saver.save(sess=sess,
                                    save_path=best_model_save_path,
                                    global_step=epoch)

    sess.close()

    return
Example #7
0
def train_net(dataset_dir, weights_path=None, net_flag='shuffle'):
    """

    :param dataset_dir:
    :param net_flag: choose which base network to use
    :param weights_path:
    :return:
    """
    train_dataset_file = ops.join(dataset_dir, 'train.txt')
    val_dataset_file = ops.join(dataset_dir, 'val.txt')

    assert ops.exists(train_dataset_file)

    train_dataset = lanenet_data_processor.DataSet(train_dataset_file)
    val_dataset = lanenet_data_processor.DataSet(val_dataset_file)

    with tf.device('/gpu:0'):
        print("gpu enableing...")
        #训练灰度图像shape要改成1吧
        input_tensor = tf.placeholder(dtype=tf.float32,
                                      shape=[
                                          CFG.TRAIN.BATCH_SIZE,
                                          CFG.TRAIN.IMG_HEIGHT,
                                          CFG.TRAIN.IMG_WIDTH, 3
                                      ],
                                      name='input_tensor')
        binary_label_tensor = tf.placeholder(dtype=tf.int64,
                                             shape=[
                                                 CFG.TRAIN.BATCH_SIZE,
                                                 CFG.TRAIN.IMG_HEIGHT,
                                                 CFG.TRAIN.IMG_WIDTH, 1
                                             ],
                                             name='binary_input_label')
        instance_label_tensor = tf.placeholder(dtype=tf.float32,
                                               shape=[
                                                   CFG.TRAIN.BATCH_SIZE,
                                                   CFG.TRAIN.IMG_HEIGHT,
                                                   CFG.TRAIN.IMG_WIDTH
                                               ],
                                               name='instance_input_label')
        phase = tf.placeholder(dtype=tf.string, shape=None, name='net_phase')

        net = shufflenet_merge_model.Shuffle_LaneNet(net_flag=net_flag,
                                                     phase=phase)

        # calculate the loss
        compute_ret = net.compute_loss(input_tensor=input_tensor,
                                       binary_label=binary_label_tensor,
                                       instance_label=instance_label_tensor,
                                       name='lanenet_model')
        total_loss = compute_ret['total_loss']
        binary_seg_loss = compute_ret['binary_seg_loss']
        disc_loss = compute_ret['discriminative_loss']
        pix_embedding = compute_ret['instance_seg_logits']

        # calculate the accuracy
        out_logits = compute_ret['binary_seg_logits']
        out_logits = tf.nn.softmax(logits=out_logits)
        out_logits_out = tf.argmax(out_logits, axis=-1)
        out = tf.argmax(out_logits, axis=-1)
        out = tf.expand_dims(out, axis=-1)

        idx = tf.where(tf.equal(binary_label_tensor, 1))
        pix_cls_ret = tf.gather_nd(out, idx)
        accuracy = tf.count_nonzero(pix_cls_ret)
        accuracy = tf.divide(accuracy,
                             tf.cast(tf.shape(pix_cls_ret)[0], tf.int64))

        global_step = tf.Variable(0, trainable=False)
        learning_rate = tf.train.exponential_decay(CFG.TRAIN.LEARNING_RATE,
                                                   global_step,
                                                   100000,
                                                   0.1,
                                                   staircase=True)
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            optimizer = tf.train.MomentumOptimizer(
                learning_rate=learning_rate,
                momentum=0.9).minimize(loss=total_loss,
                                       var_list=tf.trainable_variables(),
                                       global_step=global_step)

    # Set tf saver
    saver = tf.train.Saver()
    # Determine the checkpoint save directory
    if net_flag == 'vgg':
        model_save_dir = 'model/vgg/dvs'
    elif net_flag == 'shuffle':
        model_save_dir = 'model/shufflenet/dvs'
    else:
        raise ValueError('unsupported net_flag: {:s}'.format(net_flag))

    if not ops.exists(model_save_dir):
        os.makedirs(model_save_dir)
    train_start_time = time.strftime('%Y-%m-%d-%H-%M-%S',
                                     time.localtime(time.time()))
    model_name = 'dvs_lanenet_{:s}_{:s}.ckpt'.format(net_flag,
                                                     str(train_start_time))
    model_save_path = ops.join(model_save_dir, model_name)

    # Set tf summary
    tboard_save_path = 'tboard/tusimple_lanenet/{:s}'.format(net_flag)
    if not ops.exists(tboard_save_path):
        os.makedirs(tboard_save_path)
    train_cost_scalar = tf.summary.scalar(name='train_cost', tensor=total_loss)
    val_cost_scalar = tf.summary.scalar(name='val_cost', tensor=total_loss)
    train_accuracy_scalar = tf.summary.scalar(name='train_accuracy',
                                              tensor=accuracy)
    val_accuracy_scalar = tf.summary.scalar(name='val_accuracy',
                                            tensor=accuracy)
    train_binary_seg_loss_scalar = tf.summary.scalar(
        name='train_binary_seg_loss', tensor=binary_seg_loss)
    val_binary_seg_loss_scalar = tf.summary.scalar(name='val_binary_seg_loss',
                                                   tensor=binary_seg_loss)
    train_instance_seg_loss_scalar = tf.summary.scalar(
        name='train_instance_seg_loss', tensor=disc_loss)
    val_instance_seg_loss_scalar = tf.summary.scalar(
        name='val_instance_seg_loss', tensor=disc_loss)
    learning_rate_scalar = tf.summary.scalar(name='learning_rate',
                                             tensor=learning_rate)
    train_merge_summary_op = tf.summary.merge([
        train_accuracy_scalar, train_cost_scalar, learning_rate_scalar,
        train_binary_seg_loss_scalar, train_instance_seg_loss_scalar
    ])
    val_merge_summary_op = tf.summary.merge([
        val_accuracy_scalar, val_cost_scalar, val_binary_seg_loss_scalar,
        val_instance_seg_loss_scalar
    ])

    # Set sess configuration
    sess_config = tf.ConfigProto(allow_soft_placement=True)
    sess_config.gpu_options.per_process_gpu_memory_fraction = CFG.TRAIN.GPU_MEMORY_FRACTION
    sess_config.gpu_options.allow_growth = CFG.TRAIN.TF_ALLOW_GROWTH
    sess_config.gpu_options.allocator_type = 'BFC'

    sess = tf.Session(config=sess_config)

    summary_writer = tf.summary.FileWriter(tboard_save_path)
    summary_writer.add_graph(sess.graph)

    # Set the training parameters
    train_epochs = CFG.TRAIN.EPOCHS

    log.info('Global configuration is as follows:')
    log.info(CFG)

    with sess.as_default():

        tf.train.write_graph(
            graph_or_graph_def=sess.graph,
            logdir='',
            name='{:s}/lanenet_model.pb'.format(model_save_dir))

        if weights_path is None:
            log.info('Training from scratch')
            init = tf.global_variables_initializer()
            sess.run(init)
        else:
            log.info('Restore model from last model checkpoint {:s}'.format(
                weights_path))
            saver.restore(sess=sess, save_path=weights_path)

        # Load the pretrained parameters
        if net_flag == 'vgg' and weights_path is None:
            pretrained_weights = np.load('./data/vgg16.npy',
                                         encoding='latin1').item()

            for vv in tf.trainable_variables():
                weights_key = vv.name.split('/')[-3]
                try:
                    weights = pretrained_weights[weights_key][0]
                    _op = tf.assign(vv, weights)
                    sess.run(_op)
                except Exception as e:
                    continue

        if net_flag == 'shuffle' and weights_path is None:
            variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
            try:
                print("Loading ImageNet pretrained weights...")
                weights_dict = load_obj('./data/shufflenet_weights.pkl')
                run_list = []
                for variable in variables:
                    for key, value in weights_dict.items():
                        # Appending ':' matches the variable itself rather than
                        # the slot variables created by adaptive optimizers.
                        if key + ":" in variable.name:
                            run_list.append(tf.assign(variable, value))

                sess.run(run_list)
                print("ImageNet pretrained weights loaded\n\n")
            except (IOError, OSError):
                print("No pretrained ImageNet weights exist. Skipping...\n\n")

        # Select the per-channel mean used for preprocessing
        if net_flag == 'vgg':
            MEAN = VGG_MEAN
        elif net_flag == 'shuffle':
            MEAN = SHUFFLE_MEAN

        train_cost_time_mean = []
        val_cost_time_mean = []
        for epoch in range(train_epochs):
            # training part
            t_start = time.time()

            with tf.device('/cpu:0'):
                gt_imgs, binary_gt_labels, instance_gt_labels = train_dataset.next_batch(
                    CFG.TRAIN.BATCH_SIZE)
                gt_imgs = [
                    cv2.resize(tmp,
                               dsize=(CFG.TRAIN.IMG_WIDTH,
                                      CFG.TRAIN.IMG_HEIGHT),
                               dst=tmp,
                               interpolation=cv2.INTER_LINEAR)
                    for tmp in gt_imgs
                ]

                gt_imgs = [tmp - MEAN for tmp in gt_imgs]
                binary_gt_labels = [
                    cv2.resize(tmp,
                               dsize=(CFG.TRAIN.IMG_WIDTH,
                                      CFG.TRAIN.IMG_HEIGHT),
                               dst=tmp,
                               interpolation=cv2.INTER_NEAREST)
                    for tmp in binary_gt_labels
                ]
                binary_gt_labels = [
                    np.expand_dims(tmp, axis=-1) for tmp in binary_gt_labels
                ]
                instance_gt_labels = [
                    cv2.resize(tmp,
                               dsize=(CFG.TRAIN.IMG_WIDTH,
                                      CFG.TRAIN.IMG_HEIGHT),
                               dst=tmp,
                               interpolation=cv2.INTER_NEAREST)
                    for tmp in instance_gt_labels
                ]
            phase_train = 'train'

            _, c, train_accuracy, train_summary, binary_loss, instance_loss, embedding, binary_seg_img = \
                sess.run([optimizer, total_loss,
                          accuracy,
                          train_merge_summary_op,
                          binary_seg_loss,
                          disc_loss,
                          pix_embedding,
                          out_logits_out],
                         feed_dict={input_tensor: gt_imgs,
                                    binary_label_tensor: binary_gt_labels,
                                    instance_label_tensor: instance_gt_labels,
                                    phase: phase_train})

            if math.isnan(c) or math.isnan(binary_loss) or math.isnan(
                    instance_loss):
                log.error('cost is: {:.5f}'.format(c))
                log.error('binary cost is: {:.5f}'.format(binary_loss))
                log.error('instance cost is: {:.5f}'.format(instance_loss))
                cv2.imwrite('nan_image.png', gt_imgs[0] + MEAN)
                cv2.imwrite('nan_instance_label.png', instance_gt_labels[0])
                cv2.imwrite('nan_binary_label.png', binary_gt_labels[0] * 255)
                return

            if epoch % 100 == 0:
                cv2.imwrite('image.png', gt_imgs[0] + MEAN)
                cv2.imwrite('binary_label.png', binary_gt_labels[0] * 255)
                cv2.imwrite('instance_label.png', instance_gt_labels[0])
                cv2.imwrite('binary_seg_img.png', binary_seg_img[0] * 255)

                for i in range(4):
                    embedding[0][:, :, i] = minmax_scale(embedding[0][:, :, i])
                embedding_image = np.array(embedding[0], np.uint8)
                cv2.imwrite('embedding.png', embedding_image)

            cost_time = time.time() - t_start
            train_cost_time_mean.append(cost_time)
            summary_writer.add_summary(summary=train_summary,
                                       global_step=epoch)

            # validation part
            with tf.device('/cpu:0'):
                gt_imgs_val, binary_gt_labels_val, instance_gt_labels_val \
                    = val_dataset.next_batch(CFG.TRAIN.VAL_BATCH_SIZE)
                gt_imgs_val = [
                    cv2.resize(tmp,
                               dsize=(CFG.TRAIN.IMG_WIDTH,
                                      CFG.TRAIN.IMG_HEIGHT),
                               dst=tmp,
                               interpolation=cv2.INTER_LINEAR)
                    for tmp in gt_imgs_val
                ]
                gt_imgs_val = [tmp - MEAN for tmp in gt_imgs_val]
                binary_gt_labels_val = [
                    cv2.resize(tmp,
                               dsize=(CFG.TRAIN.IMG_WIDTH,
                                      CFG.TRAIN.IMG_HEIGHT),
                               dst=tmp,
                               interpolation=cv2.INTER_NEAREST)
                    for tmp in binary_gt_labels_val
                ]
                binary_gt_labels_val = [
                    np.expand_dims(tmp, axis=-1)
                    for tmp in binary_gt_labels_val
                ]
                instance_gt_labels_val = [
                    cv2.resize(tmp,
                               dsize=(CFG.TRAIN.IMG_WIDTH,
                                      CFG.TRAIN.IMG_HEIGHT),
                               dst=tmp,
                               interpolation=cv2.INTER_NEAREST)
                    for tmp in instance_gt_labels_val
                ]
            phase_val = 'test'

            t_start_val = time.time()
            c_val, val_summary, val_accuracy, val_binary_seg_loss, val_instance_seg_loss = \
                sess.run([total_loss, val_merge_summary_op, accuracy, binary_seg_loss, disc_loss],
                         feed_dict={input_tensor: gt_imgs_val,
                                    binary_label_tensor: binary_gt_labels_val,
                                    instance_label_tensor: instance_gt_labels_val,
                                    phase: phase_val})

            if epoch % 100 == 0:
                cv2.imwrite('test_image.png', gt_imgs_val[0] + MEAN)

            summary_writer.add_summary(val_summary, global_step=epoch)

            cost_time_val = time.time() - t_start_val
            val_cost_time_mean.append(cost_time_val)

            if epoch % CFG.TRAIN.DISPLAY_STEP == 0:
                log.info(
                    'Epoch: {:d} total_loss= {:6f} binary_seg_loss= {:6f} instance_seg_loss= {:6f} accuracy= {:6f}'
                    ' mean_cost_time= {:5f}s '.format(
                        epoch + 1, c, binary_loss, instance_loss,
                        train_accuracy, np.mean(train_cost_time_mean)))
                train_cost_time_mean.clear()

            if epoch % CFG.TRAIN.TEST_DISPLAY_STEP == 0:
                log.info(
                    'Epoch_Val: {:d} total_loss= {:6f} binary_seg_loss= {:6f} '
                    'instance_seg_loss= {:6f} accuracy= {:6f} '
                    'mean_cost_time= {:5f}s '.format(
                        epoch + 1, c_val, val_binary_seg_loss,
                        val_instance_seg_loss, val_accuracy,
                        np.mean(val_cost_time_mean)))
                val_cost_time_mean.clear()

            if epoch % 10000 == 0:
                saver.save(sess=sess,
                           save_path=model_save_path,
                           global_step=epoch)
    sess.close()

    return
Example #8
0
def train_net(dataset_dir, weights_path=None, net_flag='vgg', initial_step=0):
    """

    :param dataset_dir:
    :param net_flag: choose which base network to use
    :param weights_path:
    :return:
    """

    train_dataset_file = ops.join(dataset_dir, '7-3_random_train.txt')
    val_dataset_file = ops.join(dataset_dir, '7-3_random_val.txt')

    # train_dataset_file = ops.join(dataset_dir, '9-1_train.txt')
    # val_dataset_file = ops.join(dataset_dir, '9-1_val.txt')

    assert ops.exists(train_dataset_file)

    train_dataset = lanenet_data_processor.DataSet(train_dataset_file)
    val_dataset = lanenet_data_processor.DataSet(val_dataset_file)

    with tf.device('/gpu:0'):
        # with tf.device('/cpu:0'):
        input_tensor = tf.placeholder(dtype=tf.float32,
                                      shape=[
                                          CFG.TRAIN.BATCH_SIZE,
                                          CFG.TRAIN.IMG_HEIGHT,
                                          CFG.TRAIN.IMG_WIDTH, 3
                                      ],
                                      name='input_tensor')
        binary_label_tensor = tf.placeholder(dtype=tf.int64,
                                             shape=[
                                                 CFG.TRAIN.BATCH_SIZE,
                                                 CFG.TRAIN.IMG_HEIGHT,
                                                 CFG.TRAIN.IMG_WIDTH, 1
                                             ],
                                             name='binary_input_label')
        instance_label_tensor = tf.placeholder(dtype=tf.float32,
                                               shape=[
                                                   CFG.TRAIN.BATCH_SIZE,
                                                   CFG.TRAIN.IMG_HEIGHT,
                                                   CFG.TRAIN.IMG_WIDTH
                                               ],
                                               name='instance_input_label')
        # binary_seg_img_tensor = tf.placeholder(dtype=tf.uint8,
        #                                      shape=[CFG.TRAIN.IMG_HEIGHT,
        #                                             CFG.TRAIN.IMG_WIDTH, 1])

        phase = tf.placeholder(dtype=tf.string, shape=None, name='net_phase')

        net = lanenet_merge_model.LaneNet(net_flag=net_flag, phase=phase)

        # calculate the loss
        compute_ret = net.compute_loss(input_tensor=input_tensor,
                                       binary_label=binary_label_tensor,
                                       instance_label=instance_label_tensor,
                                       name='lanenet_model')
        total_loss = compute_ret['total_loss']
        binary_seg_loss = compute_ret['binary_seg_loss']
        disc_loss = compute_ret['discriminative_loss']
        pix_embedding = compute_ret['instance_seg_logits']

        counts = compute_ret['counts']

        # calculate the accuracy
        out_logits = compute_ret['binary_seg_logits']
        out_logits = tf.nn.softmax(logits=out_logits)
        out_logits_out = tf.argmax(
            out_logits,
            axis=-1)  # transform a 2-channel feature map into a binary image
        out = tf.argmax(out_logits, axis=-1)
        out = tf.expand_dims(out, axis=-1)

        idx = tf.where(tf.equal(binary_label_tensor,
                                1))  # select the Positive Pixels in GT image
        pix_cls_ret = tf.gather_nd(
            out, idx)  # slice out the corresponding pixels in output image
        accuracy = tf.count_nonzero(pix_cls_ret)  # True Positive
        accuracy = tf.divide(accuracy,
                             tf.cast(tf.shape(pix_cls_ret)[0], tf.int64))
        # Accuracy = TP / (TP + FN), i.e. Recall

        global_step = tf.Variable(0, trainable=False)
        learning_rate = tf.train.exponential_decay(CFG.TRAIN.LEARNING_RATE,
                                                   global_step,
                                                   CFG.TRAIN.LR_DECAY_STEPS,
                                                   CFG.TRAIN.LR_DECAY_RATE,
                                                   staircase=True)
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            optimizer = tf.train.MomentumOptimizer(
                learning_rate=learning_rate,
                momentum=0.9).minimize(loss=total_loss,
                                       var_list=tf.trainable_variables(),
                                       global_step=global_step)
            # optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
    # Set tf saver
    saver = tf.train.Saver()
    model_save_dir = 'model/tusimple_lanenet'
    if not ops.exists(model_save_dir):
        os.makedirs(model_save_dir)
    train_start_time = time.strftime('%Y-%m-%d-%H-%M-%S',
                                     time.localtime(time.time()))
    model_name = 'tusimple_lanenet_{:s}_{:s}.ckpt'.format(
        net_flag, str(train_start_time))
    model_save_path = ops.join(model_save_dir, model_name)
    img_output_dir = f'output/{net_flag}_{train_start_time}'

    # Set tf restorer
    mobile_pretrained_path = 'model/mobilenet/mobilenet_v2_1.0_224.ckpt'
    reader = tf.train.NewCheckpointReader(mobile_pretrained_path)
    restore_dict = dict()
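    # Map graph variables whose names contain 'MobilenetV2' onto the matching
    # tensor names in the ImageNet-pretrained MobileNetV2 checkpoint.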
    for v in tf.trainable_variables():
        s = v.name.split(':')[0]
        i = s.find('MobilenetV2')
        if i != -1:
            tensor_name = s[i:]
            # print(tensor_name)
            if reader.has_tensor(tensor_name):
                # print('has tensor ', tensor_name)
                restore_dict[tensor_name] = v

    pretrained_saver = tf.train.Saver(restore_dict, name="pretrained_saver")

    # Set tf summary
    tboard_save_path = f'tboard/tusimple_lanenet/{net_flag}/{train_start_time}'
    if not ops.exists(tboard_save_path):
        os.makedirs(tboard_save_path)
    train_cost_scalar = tf.summary.scalar(name='train_cost', tensor=total_loss)
    val_cost_scalar = tf.summary.scalar(name='val_cost', tensor=total_loss)
    train_accuracy_scalar = tf.summary.scalar(name='train_accuracy',
                                              tensor=accuracy)
    val_accuracy_scalar = tf.summary.scalar(name='val_accuracy',
                                            tensor=accuracy)
    train_binary_seg_loss_scalar = tf.summary.scalar(
        name='train_binary_seg_loss', tensor=binary_seg_loss)
    val_binary_seg_loss_scalar = tf.summary.scalar(name='val_binary_seg_loss',
                                                   tensor=binary_seg_loss)
    train_instance_seg_loss_scalar = tf.summary.scalar(
        name='train_instance_seg_loss', tensor=disc_loss)
    val_instance_seg_loss_scalar = tf.summary.scalar(
        name='val_instance_seg_loss', tensor=disc_loss)
    learning_rate_scalar = tf.summary.scalar(name='learning_rate',
                                             tensor=learning_rate)

    # train_bin_seg_img = tf.summary.image('Train Binary Segmentation', tensor=binary_seg_img_tensor)
    # train_raw_img = tf.summary.image('Train Raw Image', gt_imgs[0] + VGG_MEAN)
    # val_bin_seg_img = tf.summary.image('Binary Segmentation', )

    # cv2.imwrite(f'output/{train_start_time}_{net_flag}_image.png', gt_imgs[0] + VGG_MEAN)
    # cv2.imwrite(f'output/{train_start_time}_{net_flag}_binary_label.png', binary_gt_labels[0] * 255)
    # cv2.imwrite(f'output/{train_start_time}_{net_flag}_instance_label.png', instance_gt_labels[0])
    # cv2.imwrite(f'output/{train_start_time}_{net_flag}_binary_seg_img.png', binary_seg_img[0] * 255)
    #
    # cv2.imwrite(f'output/{train_start_time}_{net_flag}_embedding.png', embedding_image)
    #
    # cv2.imwrite(f'output/{train_start_time}_{net_flag}_image_VAL.png', gt_imgs_val[0] + VGG_MEAN)
    # cv2.imwrite(f'output/{train_start_time}_{net_flag}_binary_seg_img_VAL.png', val_binary_seg_img[0] * 255)

    train_merge_summary_op = tf.summary.merge([
        train_accuracy_scalar, train_cost_scalar, learning_rate_scalar,
        train_binary_seg_loss_scalar, train_instance_seg_loss_scalar
    ])  # , train_bin_seg_img
    val_merge_summary_op = tf.summary.merge([
        val_accuracy_scalar, val_cost_scalar, val_binary_seg_loss_scalar,
        val_instance_seg_loss_scalar
    ])

    # Set sess configuration
    sess_config = tf.ConfigProto(allow_soft_placement=True)
    sess_config.gpu_options.per_process_gpu_memory_fraction = CFG.TRAIN.GPU_MEMORY_FRACTION
    sess_config.gpu_options.allow_growth = CFG.TRAIN.TF_ALLOW_GROWTH
    sess_config.gpu_options.allocator_type = 'BFC'

    sess = tf.Session(config=sess_config)


    summary_writer = tf.summary.FileWriter(tboard_save_path, sess.graph)

    # Set the training parameters
    train_steps = CFG.TRAIN.STEPS

    log.info('Global configuration is as follows:')
    log.info(CFG)

    with sess.as_default():

        tf.train.write_graph(
            graph_or_graph_def=sess.graph,
            logdir='',
            name='{:s}/lanenet_model.pb'.format(model_save_dir))

        if weights_path is None:
            log.info('Training from scratch')
            init = tf.global_variables_initializer()
            sess.run(init)
        else:
            log.info('Restore model from last model checkpoint {:s}'.format(
                weights_path))
            saver.restore(sess=sess, save_path=weights_path)

        # Load the pretrained parameters
        if net_flag == 'vgg' and weights_path is None:
            pretrained_weights = np.load('./data/vgg16.npy',
                                         encoding='latin1').item()
            for vv in tf.trainable_variables():
                weights_key = vv.name.split('/')[-3]
                try:
                    weights = pretrained_weights[weights_key][0]
                    _op = tf.assign(vv, weights)
                    sess.run(_op)
                except Exception as e:
                    continue
        elif net_flag == 'mobile' and weights_path is None:
            pass
            # pretrained_saver.restore(sess=sess, save_path=mobile_pretrained_path)

        train_cost_time_mean = []
        val_cost_time_mean = []
        for step in range(int(initial_step), train_steps):
            # training part
            t_start = time.time()

            with tf.device('/gpu:0'):
                raw_imgs, binary_gt_labels, instance_gt_labels = train_dataset.next_batch(
                    CFG.TRAIN.BATCH_SIZE)
                # gt_imgs = [cv2.resize(tmp,
                #                       dsize=(CFG.TRAIN.IMG_WIDTH, CFG.TRAIN.IMG_HEIGHT),
                #                       dst=tmp,
                #                       interpolation=cv2.INTER_LINEAR)
                #            for tmp in gt_imgs]

                gt_imgs = [tmp - VGG_MEAN for tmp in raw_imgs]
                # binary_gt_labels = [cv2.resize(tmp,
                #                                dsize=(CFG.TRAIN.IMG_WIDTH, CFG.TRAIN.IMG_HEIGHT),
                #                                dst=tmp,
                #                                interpolation=cv2.INTER_NEAREST)
                #                     for tmp in binary_gt_labels]
                binary_gt_labels = [
                    np.expand_dims(tmp, axis=-1) for tmp in binary_gt_labels
                ]
                # instance_gt_labels = [cv2.resize(tmp,
                #                                  dsize=(CFG.TRAIN.IMG_WIDTH, CFG.TRAIN.IMG_HEIGHT),
                #                                  dst=tmp,
                #                                  interpolation=cv2.INTER_NEAREST)
                #                       for tmp in instance_gt_labels]
            phase_train = 'train'

            _, c, train_accuracy, train_summary, binary_loss, instance_loss, embedding, binary_seg_img, ct = \
                sess.run([optimizer, total_loss,
                          accuracy,
                          train_merge_summary_op,
                          binary_seg_loss,
                          disc_loss,
                          pix_embedding,
                          out_logits_out, counts],
                         feed_dict={input_tensor: gt_imgs,
                                    binary_label_tensor: binary_gt_labels,
                                    instance_label_tensor: instance_gt_labels,
                                    phase: phase_train})

            # binary_label_tensor = tf.assign(tf.multiply(binary_seg_img[0], 255))

            print(ct)
            if math.isnan(c) or math.isnan(binary_loss) or math.isnan(
                    instance_loss):
                log.error('cost is: {:.5f}'.format(c))
                log.error('binary cost is: {:.5f}'.format(binary_loss))
                log.error('instance cost is: {:.5f}'.format(instance_loss))
                # cv2.imwrite(f'output/{train_start_time}_{net_flag}_nan_image.png', gt_imgs[0] + VGG_MEAN)
                # cv2.imwrite(f'output/{train_start_time}_{net_flag}_nan_instance_label.png', instance_gt_labels[0])
                # cv2.imwrite(f'output/{train_start_time}_{net_flag}_nan_binary_label.png', binary_gt_labels[0] * 255)
                return

            if step % 50 == 0:
                if not os.path.exists(img_output_dir):
                    os.mkdir(img_output_dir)
                print("Image Updated...")
                cv2.imwrite(
                    img_output_dir +
                    f'/{train_start_time}_{net_flag}_TRAIN_raw.png',
                    gt_imgs[0] + VGG_MEAN)
                cv2.imwrite(
                    img_output_dir +
                    f'/{train_start_time}_{net_flag}_TRAIN_binary_label.png',
                    binary_gt_labels[0] * 255)
                cv2.imwrite(
                    img_output_dir +
                    f'/{train_start_time}_{net_flag}_TRAIN_instance_label.png',
                    instance_gt_labels[0])
                cv2.imwrite(
                    img_output_dir +
                    f'/{train_start_time}_{net_flag}_TRAIN_bin_seg.png',
                    binary_seg_img[0] * 255)

                for i in range(4):
                    embedding[0][:, :, i] = minmax_scale(embedding[0][:, :, i])
                embedding_image = np.array(embedding[0], np.uint8)
                cv2.imwrite(
                    img_output_dir +
                    f'/{train_start_time}_{net_flag}_TRAIN_embedding.png',
                    embedding_image)

            cost_time = time.time() - t_start
            train_cost_time_mean.append(cost_time)
            summary_writer.add_summary(summary=train_summary, global_step=step)

            # validation part
            with tf.device('/gpu:0'):
                gt_imgs_val, binary_gt_labels_val, instance_gt_labels_val \
                    = val_dataset.next_batch(CFG.TRAIN.VAL_BATCH_SIZE)
                # gt_imgs_val = [cv2.resize(tmp,
                #                           dsize=(CFG.TRAIN.IMG_WIDTH, CFG.TRAIN.IMG_HEIGHT),
                #                           dst=tmp,
                #                           interpolation=cv2.INTER_LINEAR)
                #                for tmp in gt_imgs_val]
                gt_imgs_val = [tmp - VGG_MEAN for tmp in gt_imgs_val]
                # binary_gt_labels_val = [cv2.resize(tmp,
                #                                    dsize=(CFG.TRAIN.IMG_WIDTH, CFG.TRAIN.IMG_HEIGHT),
                #                                    dst=tmp)
                #                         for tmp in binary_gt_labels_val]
                binary_gt_labels_val = [
                    np.expand_dims(tmp, axis=-1)
                    for tmp in binary_gt_labels_val
                ]
                # instance_gt_labels_val = [cv2.resize(tmp,
                #                                      dsize=(CFG.TRAIN.IMG_WIDTH, CFG.TRAIN.IMG_HEIGHT),
                #                                      dst=tmp,
                #                                      interpolation=cv2.INTER_NEAREST)
                #                           for tmp in instance_gt_labels_val]
            phase_val = 'test'

            t_start_val = time.time()
            c_val, val_summary, val_accuracy, val_binary_seg_loss, val_instance_seg_loss, embedding, val_binary_seg_img, val_ct = \
                sess.run([total_loss, val_merge_summary_op, accuracy, binary_seg_loss, disc_loss, pix_embedding, out_logits_out, counts],
                         feed_dict={input_tensor: gt_imgs_val,
                                    binary_label_tensor: binary_gt_labels_val,
                                    instance_label_tensor: instance_gt_labels_val,
                                    phase: phase_val})

            if step % 50 == 0:
                if not os.path.exists(img_output_dir):
                    os.mkdir(img_output_dir)
                for i in range(CFG.TRAIN.VAL_BATCH_SIZE):
                    cv2.imwrite(
                        img_output_dir +
                        f'/{train_start_time}_{net_flag}_VAL_{i}_raw.png',
                        gt_imgs_val[i] + VGG_MEAN)
                    cv2.imwrite(
                        img_output_dir +
                        f'/{train_start_time}_{net_flag}_VAL_{i}_bin_seg.png',
                        val_binary_seg_img[i] * 255)
                    for j in range(4):
                        embedding[i][:, :, j] = minmax_scale(embedding[i][:, :, j])
                    embedding_image = np.array(embedding[i], np.uint8)
                    cv2.imwrite(
                        img_output_dir +
                        f'/{train_start_time}_{net_flag}_VAL_{i}_embedding.png',
                        embedding_image)

            summary_writer.add_summary(val_summary, global_step=step)

            cost_time_val = time.time() - t_start_val
            val_cost_time_mean.append(cost_time_val)

            if step % CFG.TRAIN.DISPLAY_STEP == 0:
                log.info(
                    'Step: {:d} total_loss= {:6f} binary_seg_loss= {:6f} instance_seg_loss= {:6f} accuracy= {:6f}'
                    ' mean_cost_time= {:5f}s '.format(
                        step + 1, c, binary_loss, instance_loss,
                        train_accuracy, np.mean(train_cost_time_mean)))
                train_cost_time_mean.clear()

            if step % CFG.TRAIN.TEST_DISPLAY_STEP == 0:
                log.info(
                    'Step_Val: {:d} total_loss= {:6f} binary_seg_loss= {:6f} '
                    'instance_seg_loss= {:6f} accuracy= {:6f} '
                    'mean_cost_time= {:5f}s '.format(
                        step + 1, c_val, val_binary_seg_loss,
                        val_instance_seg_loss, val_accuracy,
                        np.mean(val_cost_time_mean)))
                val_cost_time_mean.clear()

            if step % 2000 == 0:
                saver.save(sess=sess,
                           save_path=model_save_path,
                           global_step=step)
    sess.close()

    return
Example #9
0
def train_net(dataset_dir, weights_path=None, net_flag='vgg'):
    """

    :param dataset_dir:
    :param net_flag: choose which base network to use
    :param weights_path:
    :return:
    """
    # List of all training samples
    train_dataset_file = ops.join(dataset_dir, 'train.txt')
    val_dataset_file = ops.join(dataset_dir, 'val.txt')

    assert ops.exists(train_dataset_file)

    train_dataset = lanenet_data_processor.DataSet(train_dataset_file)
    val_dataset = lanenet_data_processor.DataSet(val_dataset_file)

    input_tensor = tf.placeholder(dtype=tf.float32,
                                  shape=[CFG.TRAIN.BATCH_SIZE, CFG.TRAIN.IMG_HEIGHT,
                                         CFG.TRAIN.IMG_WIDTH, 3],
                                  name='input_tensor')
    binary_label_tensor = tf.placeholder(dtype=tf.int64,
                                         shape=[CFG.TRAIN.BATCH_SIZE, CFG.TRAIN.IMG_HEIGHT,
                                                CFG.TRAIN.IMG_WIDTH, 1],
                                         name='binary_input_label')
    instance_label_tensor = tf.placeholder(dtype=tf.float32,
                                           shape=[CFG.TRAIN.BATCH_SIZE, CFG.TRAIN.IMG_HEIGHT,
                                                  CFG.TRAIN.IMG_WIDTH],
                                           name='instance_input_label')
    phase = tf.placeholder(dtype=tf.bool, shape=None, name='net_phase')

    net = lanenet_merge_model.LaneNet(net_flag=net_flag, phase=phase)

    # calculate the loss
    compute_ret = net.compute_loss(input_tensor=input_tensor, binary_label=binary_label_tensor,
                                   instance_label=instance_label_tensor, name='lanenet_model')
    total_loss = compute_ret['total_loss']
    binary_seg_loss = compute_ret['binary_seg_loss']
    disc_loss = compute_ret['discriminative_loss']
    pix_embedding = compute_ret['instance_seg_logits']

    # calculate the accuracy
    out_logits = compute_ret['binary_seg_logits']
    out_logits = tf.nn.softmax(logits=out_logits)
    out_logits_out = tf.argmax(out_logits, axis=-1)
    out = tf.argmax(out_logits, axis=-1)
    out = tf.expand_dims(out, axis=-1)

    idx = tf.where(tf.equal(binary_label_tensor, 1))
    pix_cls_ret = tf.gather_nd(out, idx)
    recall = tf.count_nonzero(pix_cls_ret)
    recall = tf.divide(recall, tf.cast(tf.shape(pix_cls_ret)[0], tf.int64))

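    # Note: computed on ground-truth background pixels, this is really the
    # true-negative rate TN / (TN + FP) rather than precision in the usual sense.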
    idx = tf.where(tf.equal(binary_label_tensor, 0))
    pix_cls_ret = tf.gather_nd(out, idx)
    precision = tf.subtract(tf.cast(tf.shape(pix_cls_ret)[0], tf.int64), tf.count_nonzero(pix_cls_ret))
    precision = tf.divide(precision, tf.cast(tf.shape(pix_cls_ret)[0], tf.int64))

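    # Harmonic mean of the two rates above, i.e. an F1-style score.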
    accuracy = tf.divide(2.0, tf.divide(1.0, recall) + tf.divide(1.0, precision))

    global_step = tf.Variable(0, trainable=False)
    learning_rate = tf.train.exponential_decay(CFG.TRAIN.LEARNING_RATE, global_step,
                                               100000, 0.1, staircase=True)
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate, momentum=0.9)
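        # Clip every gradient element to [-1, 1] before applying it, to guard
        # against exploding gradients.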
        gradients = optimizer.compute_gradients(total_loss)
        capped_gradients = [(tf.clip_by_value(grad, -1., 1.), var) for grad, var in gradients if grad is not None]
        train_op = optimizer.apply_gradients(capped_gradients, global_step=global_step)

    # Set tf saver
    saver = tf.train.Saver()
    model_save_dir = 'model/tusimple_lanenet'
    if not ops.exists(model_save_dir):
        os.makedirs(model_save_dir)
    train_start_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
    model_name = 'tusimple_lanenet_{:s}_{:s}.ckpt'.format(net_flag, str(train_start_time))
    model_save_path = ops.join(model_save_dir, model_name)

    # Set tf summary
    tboard_save_path = 'tboard/tusimple_lanenet/{:s}'.format(net_flag)
    if not ops.exists(tboard_save_path):
        os.makedirs(tboard_save_path)
    train_cost_scalar = tf.summary.scalar(name='train_cost', tensor=total_loss)
    val_cost_scalar = tf.summary.scalar(name='val_cost', tensor=total_loss)
    train_accuracy_scalar = tf.summary.scalar(name='train_accuracy', tensor=accuracy)
    val_accuracy_scalar = tf.summary.scalar(name='val_accuracy', tensor=accuracy)
    train_binary_seg_loss_scalar = tf.summary.scalar(name='train_binary_seg_loss', tensor=binary_seg_loss)
    val_binary_seg_loss_scalar = tf.summary.scalar(name='val_binary_seg_loss', tensor=binary_seg_loss)
    train_instance_seg_loss_scalar = tf.summary.scalar(name='train_instance_seg_loss', tensor=disc_loss)
    val_instance_seg_loss_scalar = tf.summary.scalar(name='val_instance_seg_loss', tensor=disc_loss)
    learning_rate_scalar = tf.summary.scalar(name='learning_rate', tensor=learning_rate)
    train_merge_summary_op = tf.summary.merge([train_accuracy_scalar, train_cost_scalar,
                                               learning_rate_scalar, train_binary_seg_loss_scalar,
                                               train_instance_seg_loss_scalar])
    val_merge_summary_op = tf.summary.merge([val_accuracy_scalar, val_cost_scalar,
                                             val_binary_seg_loss_scalar, val_instance_seg_loss_scalar])

    # Set sess configuration
    sess_config = tf.ConfigProto(allow_soft_placement=True)
    sess_config.gpu_options.per_process_gpu_memory_fraction = CFG.TRAIN.GPU_MEMORY_FRACTION
    sess_config.gpu_options.allow_growth = CFG.TRAIN.TF_ALLOW_GROWTH
    sess_config.gpu_options.allocator_type = 'BFC'

    sess = tf.Session(config=sess_config)

    summary_writer = tf.summary.FileWriter(tboard_save_path)
    summary_writer.add_graph(sess.graph)

    # Set the training parameters
    train_epochs = CFG.TRAIN.EPOCHS

    log.info('Global configuration is as follows:')
    log.info(CFG)

    with sess.as_default():

        tf.train.write_graph(graph_or_graph_def=sess.graph, logdir='',
                             name='{:s}/lanenet_model.pbtxt'.format(model_save_dir))

        if weights_path is None:
            log.info('Training from scratch')
            init = tf.global_variables_initializer()
            sess.run(init)
        else:
            log.info('Restore model from last model checkpoint {:s}'.format(weights_path))
            saver.restore(sess=sess, save_path=weights_path)

        # Load the pretrained parameters
        if net_flag == 'vgg' and weights_path is None:
            pretrained_weights = np.load(
                './data/vgg16.npy',
                encoding='latin1').item()

            for vv in tf.trainable_variables():
                weights_key = vv.name.split('/')[-3]
                try:
                    weights = pretrained_weights[weights_key][0]
                    _op = tf.assign(vv, weights)
                    sess.run(_op)
                except Exception as e:
                    continue

        train_cost_time_mean = []
        for epoch in range(train_epochs):
            # training part
            t_start = time.time()
            gt_imgs, binary_gt_labels, instance_gt_labels = train_dataset.next_batch(CFG.TRAIN.BATCH_SIZE)
            gt_imgs = [tmp - VGG_MEAN for tmp in gt_imgs]

            _, c, train_accuracy, train_summary, binary_loss, instance_loss, embedding, binary_seg_img = \
                sess.run([train_op, total_loss,
                          accuracy,
                          train_merge_summary_op,
                          binary_seg_loss,
                          disc_loss,
                          pix_embedding,
                          out_logits_out],
                         feed_dict={input_tensor: gt_imgs,
                                    binary_label_tensor: binary_gt_labels,
                                    instance_label_tensor: instance_gt_labels,
                                    phase: True})

            if math.isnan(c) or math.isnan(binary_loss) or math.isnan(instance_loss):
                log.error('cost is: {:.5f}'.format(c))
                log.error('binary cost is: {:.5f}'.format(binary_loss))
                log.error('instance cost is: {:.5f}'.format(instance_loss))
                cv2.imwrite('nan_image.png', gt_imgs[0] + VGG_MEAN)
                cv2.imwrite('nan_instance_label.png', instance_gt_labels[0])
                cv2.imwrite('nan_binary_label.png', binary_gt_labels[0] * 255)
                return

            if epoch % 100 == 0:
                cv2.imwrite('image.png', gt_imgs[0] + VGG_MEAN)
                cv2.imwrite('binary_label.png', binary_gt_labels[0] * 255)
                cv2.imwrite('instance_label.png', instance_gt_labels[0])
                cv2.imwrite('binary_seg_img.png', binary_seg_img[0] * 255)

                for i in range(4):
                    embedding[0][:, :, i] = minmax_scale(embedding[0][:, :, i])
                embedding_image = np.array(embedding[0], np.uint8)
                cv2.imwrite('embedding.png', embedding_image)

            cost_time = time.time() - t_start
            train_cost_time_mean.append(cost_time)
            summary_writer.add_summary(summary=train_summary, global_step=epoch)

            if epoch % CFG.TRAIN.DISPLAY_STEP == 0:
                log.info('Epoch: {:d} total_loss= {:6f} binary_seg_loss= {:6f} instance_seg_loss= {:6f} accuracy= {:6f}'
                         ' mean_cost_time= {:5f}s '.
                         format(epoch + 1, c, binary_loss, instance_loss, train_accuracy,
                                np.mean(train_cost_time_mean)))
                train_cost_time_mean = []

            if epoch % 1000 == 0:
                saver.save(sess=sess, save_path=model_save_path, global_step=epoch)
    sess.close()

    return
Example #10
0
def train_net(dataset_dir, weights_path=None):

    train_dataset_file = ops.join(dataset_dir, 'train.txt')
    val_dataset_file = ops.join(dataset_dir, 'val.txt')

    assert ops.exists(train_dataset_file), '{:s} does not exist'.format(
        train_dataset_file)
    assert ops.exists(val_dataset_file), '{:s} does not exist'.format(val_dataset_file)

    # Create the training and validation dataset instances train_dataset and val_dataset
    train_dataset = lanenet_data_processor.DataSet(train_dataset_file)
    val_dataset = lanenet_data_processor.DataSet(val_dataset_file)

    # Build the TensorFlow graph
    with tf.device('/gpu:0'):
        # input_tensor: input image; binary_label_tensor: binary segmentation
        # label; instance_label_tensor: instance segmentation label; phase:
        # training/testing phase
        input_tensor = tf.placeholder(dtype=tf.float32,
                                      shape=[
                                          CFG.TRAIN.BATCH_SIZE,
                                          CFG.TRAIN.IMG_HEIGHT,
                                          CFG.TRAIN.IMG_WIDTH, 3
                                      ],
                                      name='input_tensor')
        binary_label_tensor = tf.placeholder(dtype=tf.int64,
                                             shape=[
                                                 CFG.TRAIN.BATCH_SIZE,
                                                 CFG.TRAIN.IMG_HEIGHT,
                                                 CFG.TRAIN.IMG_WIDTH, 1
                                             ],
                                             name='binary_input_label')
        instance_label_tensor = tf.placeholder(dtype=tf.float32,
                                               shape=[
                                                   CFG.TRAIN.BATCH_SIZE,
                                                   CFG.TRAIN.IMG_HEIGHT,
                                                   CFG.TRAIN.IMG_WIDTH
                                               ],
                                               name='instance_input_label')
        phase = tf.placeholder(dtype=tf.string, shape=None, name='net_phase')

        # Build the LaneNet network architecture
        net = lanenet_merge_model.LaneNet(phase=phase)

        # Compute the losses
        compute_ret = net.compute_loss(input_tensor=input_tensor,
                                       binary_label=binary_label_tensor,
                                       instance_label=instance_label_tensor,
                                       name='lanenet_model')
        total_loss = compute_ret['total_loss']
        binary_seg_loss = compute_ret['binary_seg_loss']
        disc_loss = compute_ret['discriminative_loss']
        pix_embedding = compute_ret['instance_seg_logits']

        # Compute the accuracy metric
        out_logits = compute_ret['binary_seg_logits']
        out_logits = tf.nn.softmax(logits=out_logits)
        out_logits_out = tf.argmax(out_logits, axis=-1)
        out = tf.expand_dims(out_logits_out, axis=-1)
        idx = tf.where(tf.equal(binary_label_tensor, 1))
        pix_cls_ret = tf.gather_nd(out, idx)
        accuracy = tf.count_nonzero(pix_cls_ret)
        accuracy = tf.divide(accuracy,
                             tf.cast(tf.shape(pix_cls_ret)[0], tf.int64))
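        # Note: since this metric only looks at pixels where the ground truth
        # is 1, it measures pixel recall on the lane class rather than
        # overall accuracy.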

        # Set up the global step, learning-rate schedule and optimizer
        global_step = tf.Variable(0, trainable=False)
        learning_rate = tf.train.exponential_decay(CFG.TRAIN.LEARNING_RATE,
                                                   global_step,
                                                   CFG.TRAIN.LR_DECAY_STEPS,
                                                   CFG.TRAIN.LR_DECAY_RATE,
                                                   staircase=True)
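        # With staircase=True the schedule is a step function:
        #   lr = LEARNING_RATE * LR_DECAY_RATE ** floor(global_step / LR_DECAY_STEPS)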
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            optimizer = tf.train.MomentumOptimizer(
                learning_rate=learning_rate,
                momentum=CFG.TRAIN.MOMENTUM).minimize(
                    loss=total_loss,
                    var_list=tf.trainable_variables(),
                    global_step=global_step)
    # Set up the TensorFlow Saver for saving the model
    saver = tf.train.Saver()
    # Model save directory
    model_save_dir = 'model'
    if not ops.exists(model_save_dir):
        os.makedirs(model_save_dir)
    train_start_time = time.strftime('%Y-%m-%d-%H-%M-%S',
                                     time.localtime(time.time()))
    # Model name (suffixed with the training start time)
    model_name = 'lanenet_{:s}.ckpt'.format(str(train_start_time))
    model_save_path = ops.join(model_save_dir, model_name)

    # Set up TensorFlow summaries for saving to TensorBoard
    # TensorBoard log directory (suffixed with the training start time)
    tboard_save_path = 'tboard/lanenet_{:s}'.format(str(train_start_time))
    if not ops.exists(tboard_save_path):
        os.makedirs(tboard_save_path)
    train_cost_scalar = tf.summary.scalar(name='train_cost', tensor=total_loss)
    val_cost_scalar = tf.summary.scalar(name='val_cost', tensor=total_loss)
    train_accuracy_scalar = tf.summary.scalar(name='train_accuracy',
                                              tensor=accuracy)
    val_accuracy_scalar = tf.summary.scalar(name='val_accuracy',
                                            tensor=accuracy)
    train_binary_seg_loss_scalar = tf.summary.scalar(
        name='train_binary_seg_loss', tensor=binary_seg_loss)
    val_binary_seg_loss_scalar = tf.summary.scalar(name='val_binary_seg_loss',
                                                   tensor=binary_seg_loss)
    train_instance_seg_loss_scalar = tf.summary.scalar(
        name='train_instance_seg_loss', tensor=disc_loss)
    val_instance_seg_loss_scalar = tf.summary.scalar(
        name='val_instance_seg_loss', tensor=disc_loss)
    learning_rate_scalar = tf.summary.scalar(name='learning_rate',
                                             tensor=learning_rate)
    train_merge_summary_op = tf.summary.merge([
        train_accuracy_scalar, train_cost_scalar, learning_rate_scalar,
        train_binary_seg_loss_scalar, train_instance_seg_loss_scalar
    ])
    val_merge_summary_op = tf.summary.merge([
        val_accuracy_scalar, val_cost_scalar, val_binary_seg_loss_scalar,
        val_instance_seg_loss_scalar
    ])

    # Global session configuration
    sess_config = tf.ConfigProto(allow_soft_placement=True,
                                 log_device_placement=False)
    sess_config.gpu_options.per_process_gpu_memory_fraction = CFG.TRAIN.GPU_MEMORY_FRACTION
    sess_config.gpu_options.allow_growth = CFG.TRAIN.TF_ALLOW_GROWTH
    sess_config.gpu_options.allocator_type = 'BFC'
    sess = tf.Session(config=sess_config)

    summary_writer = tf.summary.FileWriter(tboard_save_path)
    summary_writer.add_graph(sess.graph)

    # Set the global parameters for the training phase and print them out
    train_epochs = CFG.TRAIN.EPOCHS

    log.info('Global configuration is as follows:')
    log.info(CFG)

    # Open the TensorFlow session
    with sess.as_default():
        # Save the graph definition to the lanenet_model.pb file
        tf.train.write_graph(
            graph_or_graph_def=sess.graph,
            logdir='',
            name='{:s}/lanenet_model.pb'.format(model_save_dir))
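        # Note: the .pb file stores only the graph definition; the trained
        # weights are saved separately by saver.save() below.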

        # If there is no pretrained model, initialize the parameters and train from scratch; otherwise load the pretrained model for transfer learning
        if weights_path is None:
            log.info('Training from scratch')
            init = tf.global_variables_initializer()
            sess.run(init)
        else:
            log.info('Restore model from last model checkpoint {:s}'.format(
                weights_path))
            saver.restore(sess=sess, save_path=weights_path)

        train_cost_time_mean = []
        val_cost_time_mean = []

        for epoch in range(train_epochs):
            # Training step
            t_start = time.time()

            with tf.device('/cpu:0'):
                # gt_imgs: raw input images, binary_gt_labels: binary segmentation labels, instance_gt_labels: instance segmentation labels
                gt_imgs, binary_gt_labels, instance_gt_labels = train_dataset.next_batch(
                    CFG.TRAIN.BATCH_SIZE)
                gt_imgs = [
                    cv2.resize(tmp,
                               dsize=(CFG.TRAIN.IMG_WIDTH,
                                      CFG.TRAIN.IMG_HEIGHT),
                               dst=tmp,
                               interpolation=cv2.INTER_LINEAR)
                    for tmp in gt_imgs
                ]

                gt_imgs = [tmp - VGG_MEAN for tmp in gt_imgs]

                binary_gt_labels = [
                    cv2.resize(tmp,
                               dsize=(CFG.TRAIN.IMG_WIDTH,
                                      CFG.TRAIN.IMG_HEIGHT),
                               dst=tmp,
                               interpolation=cv2.INTER_NEAREST)
                    for tmp in binary_gt_labels
                ]

                binary_gt_labels = [
                    np.expand_dims(tmp, axis=-1) for tmp in binary_gt_labels
                ]
                instance_gt_labels = [
                    cv2.resize(tmp,
                               dsize=(CFG.TRAIN.IMG_WIDTH,
                                      CFG.TRAIN.IMG_HEIGHT),
                               dst=tmp,
                               interpolation=cv2.INTER_NEAREST)
                    for tmp in instance_gt_labels
                ]
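                # Note: labels are resized with INTER_NEAREST (images with
                # INTER_LINEAR) so interpolation cannot create invalid
                # in-between label values.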
            phase_train = 'train'

            # Run a LaneNet training step
            _, c, train_accuracy, train_summary, binary_loss, instance_loss, embedding, binary_seg_img = \
                sess.run([optimizer, total_loss,
                          accuracy,
                          train_merge_summary_op,
                          binary_seg_loss,
                          disc_loss,
                          pix_embedding,
                          out_logits_out],
                         feed_dict={input_tensor: gt_imgs,
                                    binary_label_tensor: binary_gt_labels,
                                    instance_label_tensor: instance_gt_labels,
                                    phase: phase_train})

            # Exception handling: when a loss is NaN, log the error, dump the offending batch and abort
            if math.isnan(c) or math.isnan(binary_loss) or math.isnan(
                    instance_loss):
                log.error('Epoch: {:d} Total cost: {:}'.format(epoch + 1, c))
                log.error('Epoch: {:d} Total binary cost: {:}'.format(
                    epoch + 1, binary_loss))
                log.error('Epoch: {:d} Total instance cost: {:}'.format(
                    epoch + 1, instance_loss))
                cv2.imwrite('nan_image.png', gt_imgs[0] + VGG_MEAN)
                cv2.imwrite('nan_instance_label.png', instance_gt_labels[0])
                cv2.imwrite('nan_binary_label.png', binary_gt_labels[0] * 255)
                return

            cost_time = time.time() - t_start
            train_cost_time_mean.append(cost_time)
            # Log the training summaries to TensorBoard
            summary_writer.add_summary(summary=train_summary,
                                       global_step=epoch)

            # Validation step
            with tf.device('/cpu:0'):
                gt_imgs_val, binary_gt_labels_val, instance_gt_labels_val \
                    = val_dataset.next_batch(CFG.TRAIN.VAL_BATCH_SIZE)
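                # Note: the placeholders above fix the batch dimension to
                # CFG.TRAIN.BATCH_SIZE, so this implicitly assumes
                # CFG.TRAIN.VAL_BATCH_SIZE equals CFG.TRAIN.BATCH_SIZE.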
                gt_imgs_val = [
                    cv2.resize(tmp,
                               dsize=(CFG.TRAIN.IMG_WIDTH,
                                      CFG.TRAIN.IMG_HEIGHT),
                               dst=tmp,
                               interpolation=cv2.INTER_LINEAR)
                    for tmp in gt_imgs_val
                ]
                gt_imgs_val = [tmp - VGG_MEAN for tmp in gt_imgs_val]
                binary_gt_labels_val = [
                    cv2.resize(tmp,
                               dsize=(CFG.TRAIN.IMG_WIDTH,
                                      CFG.TRAIN.IMG_HEIGHT),
                               dst=tmp,
                               interpolation=cv2.INTER_NEAREST)
                    for tmp in binary_gt_labels_val
                ]
                binary_gt_labels_val = [
                    np.expand_dims(tmp, axis=-1)
                    for tmp in binary_gt_labels_val
                ]
                instance_gt_labels_val = [
                    cv2.resize(tmp,
                               dsize=(CFG.TRAIN.IMG_WIDTH,
                                      CFG.TRAIN.IMG_HEIGHT),
                               dst=tmp,
                               interpolation=cv2.INTER_NEAREST)
                    for tmp in instance_gt_labels_val
                ]
            phase_val = 'test'

            t_start_val = time.time()
            # Run a LaneNet validation step
            c_val, val_summary, val_accuracy, val_binary_seg_loss, val_instance_seg_loss = \
                sess.run([total_loss, val_merge_summary_op, accuracy, binary_seg_loss, disc_loss],
                         feed_dict={input_tensor: gt_imgs_val,
                                    binary_label_tensor: binary_gt_labels_val,
                                    instance_label_tensor: instance_gt_labels_val,
                                    phase: phase_val})

            # Log the validation summaries to TensorBoard
            summary_writer.add_summary(val_summary, global_step=epoch)

            cost_time_val = time.time() - t_start_val
            val_cost_time_mean.append(cost_time_val)

            # Print the training log
            if epoch % CFG.TRAIN.DISPLAY_STEP == 0:
                log.info(
                    'Epoch: {:d} total_loss= {:6f} binary_seg_loss= {:6f} instance_seg_loss= {:6f} accuracy= {:6f}'
                    ' mean_cost_time= {:5f}s '.format(
                        epoch + 1, c, binary_loss, instance_loss,
                        train_accuracy, np.mean(train_cost_time_mean)))
                train_cost_time_mean.clear()

            # Print the validation log
            if epoch % CFG.TRAIN.TEST_DISPLAY_STEP == 0:
                log.info(
                    'Epoch_Val: {:d} total_loss= {:6f} binary_seg_loss= {:6f} '
                    'instance_seg_loss= {:6f} accuracy= {:6f} '
                    'mean_cost_time= {:5f}s '.format(
                        epoch + 1, c_val, val_binary_seg_loss,
                        val_instance_seg_loss, val_accuracy,
                        np.mean(val_cost_time_mean)))
                val_cost_time_mean.clear()

            # Save the model
            if epoch % 2000 == 0 and epoch != 0:
                saver.save(sess=sess,
                           save_path=model_save_path,
                           global_step=epoch)
    # Close the session
    sess.close()

    return
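
For reference, a minimal sketch of how train_net might be invoked; the argparse flag names below are assumptions for illustration, not part of the original example:

if __name__ == '__main__':
    import argparse

    # Hypothetical CLI wrapper around train_net; flag names are illustrative.
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset_dir', type=str, required=True,
                        help='directory containing train.txt and val.txt')
    parser.add_argument('--weights_path', type=str, default=None,
                        help='optional checkpoint to restore before training')
    args = parser.parse_args()

    train_net(args.dataset_dir, weights_path=args.weights_path)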