def convert_ckpt_into_pb_file(ckpt_file_path, pb_file_path):
    """

    :param ckpt_file_path:
    :param pb_file_path:
    :return:
    """
    # construct compute graph
    with tf.variable_scope('lanenet'):
        input_tensor = tf.placeholder(dtype=tf.float32, shape=[1, 256, 512, 3], name='input_tensor')

    net = lanenet.LaneNet(phase='test', net_flag='vgg')
    binary_seg_ret, instance_seg_ret = net.inference(input_tensor=input_tensor, name='lanenet_model')

    with tf.variable_scope('lanenet/'):
        binary_seg_ret = tf.cast(binary_seg_ret, dtype=tf.float32)
        binary_seg_ret = tf.squeeze(binary_seg_ret, axis=0, name='final_binary_output')
        instance_seg_ret = tf.squeeze(instance_seg_ret, axis=0, name='final_pixel_embedding_output')

    # create a session
    saver = tf.train.Saver()

    sess_config = tf.ConfigProto()
    sess_config.gpu_options.per_process_gpu_memory_fraction = 0.85
    sess_config.gpu_options.allow_growth = False
    sess_config.gpu_options.allocator_type = 'BFC'

    sess = tf.Session(config=sess_config)

    with sess.as_default():
        saver.restore(sess, ckpt_file_path)

        converted_graph_def = tf.graph_util.convert_variables_to_constants(
            sess,
            input_graph_def=sess.graph.as_graph_def(),
            output_node_names=[
                'lanenet/input_tensor',
                'lanenet/final_binary_output',
                'lanenet/final_pixel_embedding_output'
            ]
        )

        with tf.gfile.GFile(pb_file_path, "wb") as f:
            f.write(converted_graph_def.SerializeToString())
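
# A minimal sketch, not part of the original script, showing how the frozen graph written
# by convert_ckpt_into_pb_file could be loaded back for inference. The tensor names follow
# the output_node_names listed above; the default file path, the helper name and the dummy
# input are placeholder assumptions.
def load_frozen_lanenet_sketch(pb_file_path='lanenet.pb'):
    graph_def = tf.GraphDef()
    with tf.gfile.GFile(pb_file_path, 'rb') as f:
        graph_def.ParseFromString(f.read())

    graph = tf.Graph()
    with graph.as_default():
        # import without a name prefix so the original node names are kept
        tf.import_graph_def(graph_def, name='')

    input_tensor = graph.get_tensor_by_name('lanenet/input_tensor:0')
    binary_output = graph.get_tensor_by_name('lanenet/final_binary_output:0')
    embedding_output = graph.get_tensor_by_name('lanenet/final_pixel_embedding_output:0')

    with tf.Session(graph=graph) as sess:
        # dummy zero input; real images should be resized to 512x256 and scaled
        # with image / 127.5 - 1.0, as in the test functions below
        dummy_image = np.zeros(shape=[1, 256, 512, 3], dtype=np.float32)
        binary_seg, pix_embedding = sess.run(
            [binary_output, embedding_output],
            feed_dict={input_tensor: dummy_image})
    return binary_seg, pix_embedding
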
def test_lanenet_batch(src_dir, weights_path, save_dir):
    """

    :param src_dir:
    :param weights_path:
    :param save_dir:
    :return:
    """
    assert ops.exists(src_dir), '{:s} not exist'.format(src_dir)

    os.makedirs(save_dir, exist_ok=True)

    input_tensor = tf.placeholder(dtype=tf.float32,
                                  shape=[1, 256, 512, 3],
                                  name='input_tensor')

    net = lanenet.LaneNet(phase='test', net_flag='vgg')
    binary_seg_ret, instance_seg_ret = net.inference(input_tensor=input_tensor,
                                                     name='lanenet_model')

    postprocessor = lanenet_postprocess.LaneNetPostProcessor()

    saver = tf.train.Saver()

    # Set sess configuration
    sess_config = tf.ConfigProto()
    sess_config.gpu_options.per_process_gpu_memory_fraction = CFG.TEST.GPU_MEMORY_FRACTION
    sess_config.gpu_options.allow_growth = CFG.TRAIN.TF_ALLOW_GROWTH
    sess_config.gpu_options.allocator_type = 'BFC'

    sess = tf.Session(config=sess_config)

    with sess.as_default():

        saver.restore(sess=sess, save_path=weights_path)

        image_list = glob.glob('{:s}/**/*.jpg'.format(src_dir), recursive=True)
        avg_time_cost = []

        vid_writer = None  # optional cv2.VideoWriter; released after the loop only if it was created
        # frame_count = 0
        # ignore_count_1 = 0
        for index, image_path in tqdm.tqdm(enumerate(image_list),
                                           total=len(image_list)):

            image = cv2.imread(image_path, cv2.IMREAD_COLOR)
            image_vis = image
            image = cv2.resize(image, (512, 256),
                               interpolation=cv2.INTER_LINEAR)
            image = image / 127.5 - 1.0

            t_start = time.time()
            binary_seg_image, instance_seg_image = sess.run(
                [binary_seg_ret, instance_seg_ret],
                feed_dict={input_tensor: [image]})
            avg_time_cost.append(time.time() - t_start)

            postprocess_result = postprocessor.postprocess(
                binary_seg_result=binary_seg_image[0],
                instance_seg_result=instance_seg_image[0],
                source_image=image_vis)

            if index % 100 == 0:
                log.info(
                    'Mean inference time every single image: {:.5f}s'.format(
                        np.mean(avg_time_cost)))
                avg_time_cost.clear()

            input_image_dir = ops.split(image_path.split('clips')[1])[0][1:]
            input_image_name = ops.split(image_path)[1]
            output_image_dir = ops.join(save_dir, input_image_dir)
            os.makedirs(output_image_dir, exist_ok=True)
            output_image_path = ops.join(output_image_dir, input_image_name)
            if ops.exists(output_image_path):
                continue
            # print(output_image_path)

            try:
                cv2.imwrite(output_image_path,
                            postprocess_result['source_image'])
            except Exception:
                continue
    if vid_writer is not None:
        vid_writer.release()
    return
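
# A hedged usage sketch, not part of the original script. test_lanenet_batch expects a
# TuSimple-style layout in which every image path contains a 'clips' component, e.g.
# <src_dir>/clips/0530/1492626047222176976_0/20.jpg, because the output sub-directory is
# rebuilt from the portion of the path after 'clips'. All paths below are placeholder values.
def run_batch_test_sketch():
    test_lanenet_batch(src_dir='./data/tusimple_test_set',
                       weights_path='./model/tusimple_lanenet/tusimple_lanenet.ckpt',
                       save_dir='./out/tusimple_test_ret')
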
def train_lanenet_multi_gpu(dataset_dir, weights_path=None, net_flag='vgg'):
    """
    train lanenet with multi gpu
    :param dataset_dir:
    :param weights_path:
    :param net_flag:
    :return:
    """
    # set lanenet dataset
    train_dataset = lanenet_data_feed_pipline.LaneNetDataFeeder(
        dataset_dir=dataset_dir, flags='train')
    val_dataset = lanenet_data_feed_pipline.LaneNetDataFeeder(
        dataset_dir=dataset_dir, flags='val')

    # set lanenet
    train_net = lanenet.LaneNet(net_flag=net_flag, phase='train', reuse=False)
    val_net = lanenet.LaneNet(net_flag=net_flag, phase='val', reuse=True)

    # set compute graph node
    train_images, train_binary_labels, train_instance_labels = train_dataset.inputs(
        CFG.TRAIN.BATCH_SIZE, 1)
    val_images, val_binary_labels, val_instance_labels = val_dataset.inputs(
        CFG.TRAIN.VAL_BATCH_SIZE, 1)

    # set average container
    tower_grads = []
    train_tower_loss = []
    val_tower_loss = []
    batchnorm_updates = None
    train_summary_op_updates = None

    # set lr
    global_step = tf.Variable(0, trainable=False)
    learning_rate = tf.train.polynomial_decay(
        learning_rate=CFG.TRAIN.LEARNING_RATE,
        global_step=global_step,
        decay_steps=CFG.TRAIN.EPOCHS,
        power=0.9)

    # set optimizer
    optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate,
                                           momentum=CFG.TRAIN.MOMENTUM)

    # set distributed train op
    with tf.variable_scope(tf.get_variable_scope()):
        for i in range(CFG.TRAIN.GPU_NUM):
            with tf.device('/gpu:{:d}'.format(i)):
                with tf.name_scope('tower_{:d}'.format(i)) as _:
                    train_loss, grads = compute_net_gradients(
                        train_images, train_binary_labels,
                        train_instance_labels, train_net, optimizer)

                    # Only use the mean and var in the first gpu tower to update the parameter
                    if i == 0:
                        batchnorm_updates = tf.get_collection(
                            tf.GraphKeys.UPDATE_OPS)
                        train_summary_op_updates = tf.get_collection(
                            tf.GraphKeys.SUMMARIES)
                    tower_grads.append(grads)
                    train_tower_loss.append(train_loss)

                with tf.name_scope('validation_{:d}'.format(i)) as _:
                    val_loss, _ = compute_net_gradients(
                        val_images, val_binary_labels, val_instance_labels,
                        val_net, optimizer)
                    val_tower_loss.append(val_loss)

    grads = average_gradients(tower_grads)
    avg_train_loss = tf.reduce_mean(train_tower_loss)
    avg_val_loss = tf.reduce_mean(val_tower_loss)

    # Track the moving averages of all trainable variables
    variable_averages = tf.train.ExponentialMovingAverage(
        CFG.TRAIN.MOVING_AVERAGE_DECAY, num_updates=global_step)
    variables_to_average = (tf.trainable_variables() +
                            tf.moving_average_variables())
    variables_averages_op = variable_averages.apply(variables_to_average)

    # Group all the op needed for training
    batchnorm_updates_op = tf.group(*batchnorm_updates)
    apply_gradient_op = optimizer.apply_gradients(grads,
                                                  global_step=global_step)
    train_op = tf.group(apply_gradient_op, variables_averages_op,
                        batchnorm_updates_op)

    # Set tf summary save path
    tboard_save_path = 'tboard/tusimple_lanenet_multi_gpu_{:s}'.format(
        net_flag)
    os.makedirs(tboard_save_path, exist_ok=True)

    summary_writer = tf.summary.FileWriter(tboard_save_path)

    avg_train_loss_scalar = tf.summary.scalar(name='average_train_loss',
                                              tensor=avg_train_loss)
    avg_val_loss_scalar = tf.summary.scalar(name='average_val_loss',
                                            tensor=avg_val_loss)
    learning_rate_scalar = tf.summary.scalar(name='learning_rate_scalar',
                                             tensor=learning_rate)

    train_merge_summary_op = tf.summary.merge(
        [avg_train_loss_scalar, learning_rate_scalar] +
        train_summary_op_updates)
    val_merge_summary_op = tf.summary.merge([avg_val_loss_scalar])

    # set tensorflow saver
    saver = tf.train.Saver()
    model_save_dir = 'model/tusimple_lanenet_multi_gpu_{:s}'.format(net_flag)
    os.makedirs(model_save_dir, exist_ok=True)
    train_start_time = time.strftime('%Y-%m-%d-%H-%M-%S',
                                     time.localtime(time.time()))
    model_name = 'tusimple_lanenet_{:s}_{:s}.ckpt'.format(
        net_flag, str(train_start_time))
    model_save_path = ops.join(model_save_dir, model_name)

    # set sess config
    sess_config = tf.ConfigProto(device_count={'GPU': CFG.TRAIN.GPU_NUM},
                                 allow_soft_placement=True)
    sess_config.gpu_options.per_process_gpu_memory_fraction = CFG.TRAIN.GPU_MEMORY_FRACTION
    sess_config.gpu_options.allow_growth = CFG.TRAIN.TF_ALLOW_GROWTH
    sess_config.gpu_options.allocator_type = 'BFC'

    # Set the training parameters
    train_epochs = CFG.TRAIN.EPOCHS

    log.info('Global configuration is as follows:')
    log.info(CFG)

    sess = tf.Session(config=sess_config)

    summary_writer.add_graph(sess.graph)

    with sess.as_default():

        tf.train.write_graph(
            graph_or_graph_def=sess.graph,
            logdir='',
            name='{:s}/lanenet_model.pb'.format(model_save_dir))

        if weights_path is None:
            log.info('Training from scratch')
            init = tf.global_variables_initializer()
            sess.run(init)
        else:
            log.info('Restore model from last model checkpoint {:s}'.format(
                weights_path))
            saver.restore(sess=sess, save_path=weights_path)

        train_cost_time_mean = []
        val_cost_time_mean = []

        for epoch in range(train_epochs):

            # training part
            t_start = time.time()

            _, train_loss_value, train_summary, lr = \
                sess.run(
                    fetches=[train_op, avg_train_loss,
                             train_merge_summary_op, learning_rate]
                )

            if math.isnan(train_loss_value):
                log.error('Train loss is nan')
                return

            cost_time = time.time() - t_start
            train_cost_time_mean.append(cost_time)

            summary_writer.add_summary(summary=train_summary,
                                       global_step=epoch)

            # validation part
            t_start_val = time.time()

            val_loss_value, val_summary = \
                sess.run(fetches=[avg_val_loss, val_merge_summary_op])

            summary_writer.add_summary(val_summary, global_step=epoch)

            cost_time_val = time.time() - t_start_val
            val_cost_time_mean.append(cost_time_val)

            if epoch % CFG.TRAIN.DISPLAY_STEP == 0:
                log.info('Epoch_Train: {:d} total_loss= {:6f} '
                         'lr= {:6f} mean_cost_time= {:5f}s '.format(
                             epoch + 1, train_loss_value, lr,
                             np.mean(train_cost_time_mean)))
                train_cost_time_mean.clear()

            if epoch % CFG.TRAIN.VAL_DISPLAY_STEP == 0:
                log.info('Epoch_Val: {:d} total_loss= {:6f}'
                         ' mean_cost_time= {:5f}s '.format(
                             epoch + 1, val_loss_value,
                             np.mean(val_cost_time_mean)))
                val_cost_time_mean.clear()

            if epoch % 2000 == 0:
                saver.save(sess=sess,
                           save_path=model_save_path,
                           global_step=epoch)
    return
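
# A sketch of the gradient-averaging step used in train_lanenet_multi_gpu above. This is an
# assumption of what the repository's average_gradients helper typically looks like in TF1
# multi-GPU training (the helper name below is hypothetical): each element of tower_grads is
# a list of (gradient, variable) pairs from one GPU tower, and the gradients belonging to
# the same variable are averaged across towers before being applied.
def average_gradients_sketch(tower_grads):
    average_grads = []
    for grad_and_vars in zip(*tower_grads):
        # grad_and_vars holds the (grad, var) pair of a single variable from every tower
        grads = [tf.expand_dims(g, 0) for g, _ in grad_and_vars]
        grad = tf.reduce_mean(tf.concat(grads, axis=0), axis=0)
        # the variable itself is shared across towers, so the first tower's reference suffices
        average_grads.append((grad, grad_and_vars[0][1]))
    return average_grads
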
def test_lanenet(weights_path, in_image=None, image_path=None, session=None):
    """
    Tests the lanenet model on the image passes as an argument.
    :param image_path: path to the input image
    :param weights_path: path to the weights of the lanenet model
    :param in_image: input image
    :return: output image with detected lanes
    """

    # t_start = time.time()
    if in_image is None:
        # make sure that the path is valid
        assert ops.exists(image_path), '{:s} not exist'.format(image_path)

        # log.info('Start reading image and preprocessing')
        # read image from that path
        image = cv2.imread(image_path, cv2.IMREAD_COLOR)
    else:
        image = in_image.copy()

    image_vis = image
    # Resize the image to the standard dimensions
    image = cv2.resize(image, (512, 256), interpolation=cv2.INTER_LINEAR)
    image = image / 127.5 - 1.0
    # log.info('Image load complete, cost time: {:.5f}s'.format(time.time() - t_start))

    # create an empty place holder of a specified size and of type float
    input_tensor = tf.placeholder(dtype=tf.float32,
                                  shape=[1, 256, 512, 3],
                                  name='input_tensor')

    # Initialize the Lanenet model
    net = lanenet.LaneNet(phase='test', net_flag='vgg')

    # Make predictions using lanenet
    binary_seg_ret, instance_seg_ret = net.inference(input_tensor=input_tensor,
                                                     name='lanenet_model')

    # Instantiate the postprocessor
    postprocessor = lanenet_postprocess.LaneNetPostProcessor()

    # Save the current instance
    saver = tf.train.Saver()

    # Set sess configuration
    # sess_config = tf.ConfigProto(device_count={'GPU': 0})
    # sess_config.gpu_options.per_process_gpu_memory_fraction = CFG.TEST.GPU_MEMORY_FRACTION
    # sess_config.gpu_options.allow_growth = CFG.TRAIN.TF_ALLOW_GROWTH
    # sess_config.gpu_options.allocator_type = 'BFC'

    # sess = tf.Session(config=sess_config)

    sess = session

    output_img = None

    with sess.as_default():

        saver.restore(sess=sess, save_path=weights_path)

        # t_start = time.time()
        binary_seg_image, instance_seg_image = sess.run(
            [binary_seg_ret, instance_seg_ret],
            feed_dict={input_tensor: [image]})
        # t_cost = time.time() - t_start
        # log.info('Single image inference cost time: {:.5f}s'.format(t_cost))

        postprocess_result = postprocessor.postprocess(
            binary_seg_result=binary_seg_image[0],
            instance_seg_result=instance_seg_image[0],
            source_image=image_vis)
        # mask_image = postprocess_result['mask_image']

        for i in range(CFG.TRAIN.EMBEDDING_FEATS_DIMS):
            instance_seg_image[0][:, :, i] = minmax_scale(
                instance_seg_image[0][:, :, i])
        # embedding_image = np.array(instance_seg_image[0], np.uint8)

        output_img = postprocess_result['source_image']

        # plt.figure('mask_image')
        # plt.imshow(mask_image[:, :, (2, 1, 0)])
        # plt.figure('src_image')
        # plt.imshow(image_vis[:, :, (2, 1, 0)])
        # plt.figure('instance_image')
        # plt.imshow(embedding_image[:, :, (2, 1, 0)])
        # plt.figure('binary_image')
        # plt.imshow(binary_seg_image[0] * 255, cmap='gray')
        # plt.show()
        #
        # cv2.imwrite('instance_mask_image.png', mask_image)
        # cv2.imwrite('source_image.png', postprocess_result['source_image'])
        # cv2.imwrite('binary_mask_image.png', binary_seg_image[0] * 255)

    # sess.close()

    return output_img
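
# A minimal usage sketch, not part of the original script: test_lanenet runs inside a
# session supplied by the caller, so a tf.Session has to be created first. The paths, the
# memory fraction and the helper name are placeholder assumptions.
def run_single_image_test_sketch():
    sess_config = tf.ConfigProto()
    sess_config.gpu_options.per_process_gpu_memory_fraction = 0.5
    sess = tf.Session(config=sess_config)
    result_image = test_lanenet(
        weights_path='./model/tusimple_lanenet/tusimple_lanenet.ckpt',
        image_path='./data/tusimple_test_image/0.jpg',
        session=sess)
    sess.close()
    return result_image
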
def train_lanenet(dataset_dir, weights_path=None, net_flag='vgg'):
    """

    :param dataset_dir:
    :param net_flag: choose which base network to use
    :param weights_path:
    :return:
    """
    train_dataset = lanenet_data_feed_pipline.LaneNetDataFeeder(
        dataset_dir=dataset_dir, flags='train')
    val_dataset = lanenet_data_feed_pipline.LaneNetDataFeeder(
        dataset_dir=dataset_dir, flags='val')

    with tf.device('/gpu:1'):
        # set lanenet
        train_net = lanenet.LaneNet(net_flag=net_flag,
                                    phase='train',
                                    reuse=False)
        val_net = lanenet.LaneNet(net_flag=net_flag, phase='val', reuse=True)

        # set compute graph node for training
        train_images, train_binary_labels, train_instance_labels = train_dataset.inputs(
            CFG.TRAIN.BATCH_SIZE, 1)

        train_compute_ret = train_net.compute_loss(
            input_tensor=train_images,
            binary_label=train_binary_labels,
            instance_label=train_instance_labels,
            name='lanenet_model')
        train_total_loss = train_compute_ret['total_loss']
        train_binary_seg_loss = train_compute_ret['binary_seg_loss']
        train_disc_loss = train_compute_ret['discriminative_loss']
        train_pix_embedding = train_compute_ret['instance_seg_logits']

        train_prediction_logits = train_compute_ret['binary_seg_logits']
        train_prediction_score = tf.nn.softmax(logits=train_prediction_logits)
        train_prediction = tf.argmax(train_prediction_score, axis=-1)

        train_accuracy = evaluate_model_utils.calculate_model_precision(
            train_compute_ret['binary_seg_logits'], train_binary_labels)
        train_fp = evaluate_model_utils.calculate_model_fp(
            train_compute_ret['binary_seg_logits'], train_binary_labels)
        train_fn = evaluate_model_utils.calculate_model_fn(
            train_compute_ret['binary_seg_logits'], train_binary_labels)
        train_binary_seg_ret_for_summary = evaluate_model_utils.get_image_summary(
            img=train_prediction)
        train_embedding_ret_for_summary = evaluate_model_utils.get_image_summary(
            img=train_pix_embedding)

        train_cost_scalar = tf.summary.scalar(name='train_cost',
                                              tensor=train_total_loss)
        train_accuracy_scalar = tf.summary.scalar(name='train_accuracy',
                                                  tensor=train_accuracy)
        train_binary_seg_loss_scalar = tf.summary.scalar(
            name='train_binary_seg_loss', tensor=train_binary_seg_loss)
        train_instance_seg_loss_scalar = tf.summary.scalar(
            name='train_instance_seg_loss', tensor=train_disc_loss)
        train_fn_scalar = tf.summary.scalar(name='train_fn', tensor=train_fn)
        train_fp_scalar = tf.summary.scalar(name='train_fp', tensor=train_fp)
        train_binary_seg_ret_img = tf.summary.image(
            name='train_binary_seg_ret',
            tensor=train_binary_seg_ret_for_summary)
        train_embedding_feats_ret_img = tf.summary.image(
            name='train_embedding_feats_ret',
            tensor=train_embedding_ret_for_summary)
        train_merge_summary_op = tf.summary.merge([
            train_accuracy_scalar, train_cost_scalar,
            train_binary_seg_loss_scalar, train_instance_seg_loss_scalar,
            train_fn_scalar, train_fp_scalar, train_binary_seg_ret_img,
            train_embedding_feats_ret_img
        ])

        # set compute graph node for validation
        val_images, val_binary_labels, val_instance_labels = val_dataset.inputs(
            CFG.TRAIN.VAL_BATCH_SIZE, 1)

        val_compute_ret = val_net.compute_loss(
            input_tensor=val_images,
            binary_label=val_binary_labels,
            instance_label=val_instance_labels,
            name='lanenet_model')
        val_total_loss = val_compute_ret['total_loss']
        val_binary_seg_loss = val_compute_ret['binary_seg_loss']
        val_disc_loss = val_compute_ret['discriminative_loss']
        val_pix_embedding = val_compute_ret['instance_seg_logits']

        val_prediction_logits = val_compute_ret['binary_seg_logits']
        val_prediction_score = tf.nn.softmax(logits=val_prediction_logits)
        val_prediction = tf.argmax(val_prediction_score, axis=-1)

        val_accuracy = evaluate_model_utils.calculate_model_precision(
            val_compute_ret['binary_seg_logits'], val_binary_labels)
        val_fp = evaluate_model_utils.calculate_model_fp(
            val_compute_ret['binary_seg_logits'], val_binary_labels)
        val_fn = evaluate_model_utils.calculate_model_fn(
            val_compute_ret['binary_seg_logits'], val_binary_labels)
        val_binary_seg_ret_for_summary = evaluate_model_utils.get_image_summary(
            img=val_prediction)
        val_embedding_ret_for_summary = evaluate_model_utils.get_image_summary(
            img=val_pix_embedding)

        val_cost_scalar = tf.summary.scalar(name='val_cost',
                                            tensor=val_total_loss)
        val_accuracy_scalar = tf.summary.scalar(name='val_accuracy',
                                                tensor=val_accuracy)
        val_binary_seg_loss_scalar = tf.summary.scalar(
            name='val_binary_seg_loss', tensor=val_binary_seg_loss)
        val_instance_seg_loss_scalar = tf.summary.scalar(
            name='val_instance_seg_loss', tensor=val_disc_loss)
        val_fn_scalar = tf.summary.scalar(name='val_fn', tensor=val_fn)
        val_fp_scalar = tf.summary.scalar(name='val_fp', tensor=val_fp)
        val_binary_seg_ret_img = tf.summary.image(
            name='val_binary_seg_ret', tensor=val_binary_seg_ret_for_summary)
        val_embedding_feats_ret_img = tf.summary.image(
            name='val_embedding_feats_ret',
            tensor=val_embedding_ret_for_summary)
        val_merge_summary_op = tf.summary.merge([
            val_accuracy_scalar, val_cost_scalar, val_binary_seg_loss_scalar,
            val_instance_seg_loss_scalar, val_fn_scalar, val_fp_scalar,
            val_binary_seg_ret_img, val_embedding_feats_ret_img
        ])

        # set optimizer
        global_step = tf.Variable(0, trainable=False)
        learning_rate = tf.train.polynomial_decay(
            learning_rate=CFG.TRAIN.LEARNING_RATE,
            global_step=global_step,
            decay_steps=CFG.TRAIN.EPOCHS,
            power=0.9)

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            optimizer = tf.train.MomentumOptimizer(
                learning_rate=learning_rate,
                momentum=CFG.TRAIN.MOMENTUM).minimize(
                    loss=train_total_loss,
                    var_list=tf.trainable_variables(),
                    global_step=global_step)

    # Set tf model save path
    model_save_dir = 'model/tusimple_lanenet_{:s}'.format(net_flag)
    os.makedirs(model_save_dir, exist_ok=True)
    train_start_time = time.strftime('%Y-%m-%d-%H-%M-%S',
                                     time.localtime(time.time()))
    model_name = 'tusimple_lanenet_{:s}_{:s}.ckpt'.format(
        net_flag, str(train_start_time))
    model_save_path = ops.join(model_save_dir, model_name)
    saver = tf.train.Saver()

    # Set tf summary save path
    tboard_save_path = 'tboard/tusimple_lanenet_{:s}'.format(net_flag)
    os.makedirs(tboard_save_path, exist_ok=True)

    # Set sess configuration
    sess_config = tf.ConfigProto(allow_soft_placement=True)
    sess_config.gpu_options.per_process_gpu_memory_fraction = CFG.TRAIN.GPU_MEMORY_FRACTION
    sess_config.gpu_options.allow_growth = CFG.TRAIN.TF_ALLOW_GROWTH
    sess_config.gpu_options.allocator_type = 'BFC'

    sess = tf.Session(config=sess_config)

    summary_writer = tf.summary.FileWriter(tboard_save_path)
    summary_writer.add_graph(sess.graph)

    # Set the training parameters
    train_epochs = CFG.TRAIN.EPOCHS

    log.info('Global configuration is as follows:')
    log.info(CFG)

    with sess.as_default():

        if weights_path is None:
            log.info('Training from scratch')
            init = tf.global_variables_initializer()
            sess.run(init)
        else:
            log.info('Restore model from last model checkpoint {:s}'.format(
                weights_path))
            saver.restore(sess=sess, save_path=weights_path)

        if net_flag == 'vgg' and weights_path is None:
            load_pretrained_weights(tf.trainable_variables(),
                                    './data/vgg16.npy', sess)

        train_cost_time_mean = []
        for epoch in range(train_epochs):
            # training part
            t_start = time.time()

            _, train_c, train_accuracy_figure, train_fn_figure, train_fp_figure, \
                lr, train_summary, train_binary_loss, \
                train_instance_loss, train_embeddings, train_binary_seg_imgs, train_gt_imgs, \
                train_binary_gt_labels, train_instance_gt_labels = \
                sess.run([optimizer, train_total_loss, train_accuracy, train_fn, train_fp,
                          learning_rate, train_merge_summary_op, train_binary_seg_loss,
                          train_disc_loss, train_pix_embedding, train_prediction,
                          train_images, train_binary_labels, train_instance_labels])

            if math.isnan(train_c) or math.isnan(
                    train_binary_loss) or math.isnan(train_instance_loss):
                log.error('cost is: {:.5f}'.format(train_c))
                log.error('binary cost is: {:.5f}'.format(train_binary_loss))
                log.error(
                    'instance cost is: {:.5f}'.format(train_instance_loss))
                return

            if epoch % 100 == 0:
                record_training_intermediate_result(
                    gt_images=train_gt_imgs,
                    gt_binary_labels=train_binary_gt_labels,
                    gt_instance_labels=train_instance_gt_labels,
                    binary_seg_images=train_binary_seg_imgs,
                    pix_embeddings=train_embeddings)
            summary_writer.add_summary(summary=train_summary,
                                       global_step=epoch)

            if epoch % CFG.TRAIN.DISPLAY_STEP == 0:
                log.info(
                    'Epoch: {:d} total_loss= {:6f} binary_seg_loss= {:6f} '
                    'instance_seg_loss= {:6f} accuracy= {:6f} fp= {:6f} fn= {:6f}'
                    ' lr= {:6f} mean_cost_time= {:5f}s '.format(
                        epoch + 1, train_c, train_binary_loss,
                        train_instance_loss, train_accuracy_figure,
                        train_fp_figure, train_fn_figure, lr,
                        np.mean(train_cost_time_mean)))
                train_cost_time_mean.clear()

            # validation part
            val_c, val_accuracy_figure, val_fn_figure, val_fp_figure, \
                val_summary, val_binary_loss, val_instance_loss, \
                val_embeddings, val_binary_seg_imgs, val_gt_imgs, \
                val_binary_gt_labels, val_instance_gt_labels = \
                sess.run([val_total_loss, val_accuracy, val_fn, val_fp,
                          val_merge_summary_op, val_binary_seg_loss,
                          val_disc_loss, val_pix_embedding, val_prediction,
                          val_images, val_binary_labels, val_instance_labels])

            if math.isnan(val_c) or math.isnan(val_binary_loss) or math.isnan(
                    val_instance_loss):
                log.error('cost is: {:.5f}'.format(val_c))
                log.error('binary cost is: {:.5f}'.format(val_binary_loss))
                log.error('instance cost is: {:.5f}'.format(val_instance_loss))
                return

            if epoch % 100 == 0:
                record_training_intermediate_result(
                    gt_images=val_gt_imgs,
                    gt_binary_labels=val_binary_gt_labels,
                    gt_instance_labels=val_instance_gt_labels,
                    binary_seg_images=val_binary_seg_imgs,
                    pix_embeddings=val_embeddings,
                    flag='val')

            cost_time = time.time() - t_start
            train_cost_time_mean.append(cost_time)
            summary_writer.add_summary(summary=val_summary, global_step=epoch)

            if epoch % CFG.TRAIN.VAL_DISPLAY_STEP == 0:
                log.info(
                    'Epoch_Val: {:d} total_loss= {:6f} binary_seg_loss= {:6f} '
                    'instance_seg_loss= {:6f} accuracy= {:6f} fp= {:6f} fn= {:6f}'
                    ' mean_cost_time= {:5f}s '.format(
                        epoch + 1, val_c, val_binary_loss, val_instance_loss,
                        val_accuracy_figure, val_fp_figure, val_fn_figure,
                        np.mean(train_cost_time_mean)))
                train_cost_time_mean.clear()

            if epoch % 2000 == 0:
                saver.save(sess=sess,
                           save_path=model_save_path,
                           global_step=global_step)

    return
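
# A hedged command-line entry point sketch, not part of the original script, for the two
# training functions above; the flag names, defaults and helper name are assumptions.
def parse_args_and_train_sketch():
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset_dir', type=str, required=True,
                        help='directory of the prepared training dataset')
    parser.add_argument('--weights_path', type=str, default=None,
                        help='optional checkpoint to restore before training')
    parser.add_argument('--net_flag', type=str, default='vgg',
                        help='base network flag, e.g. vgg')
    parser.add_argument('--multi_gpu', action='store_true',
                        help='train with train_lanenet_multi_gpu instead of train_lanenet')
    args = parser.parse_args()

    if args.multi_gpu:
        train_lanenet_multi_gpu(args.dataset_dir, args.weights_path, net_flag=args.net_flag)
    else:
        train_lanenet(args.dataset_dir, args.weights_path, net_flag=args.net_flag)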