Example #1
def main(argv=None):
    import os
    if os.path.exists(FLAGS.result_path):
        shutil.rmtree(FLAGS.result_path)
    os.makedirs(FLAGS.result_path)

    os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu_list
    pascal_voc_lut = pascal_segmentation_lut()

    with tf.get_default_graph().as_default():
        input_images = tf.placeholder(tf.float32,
                                      shape=[None, None, None, 3],
                                      name='input_images')
        global_step = tf.get_variable('global_step', [],
                                      initializer=tf.constant_initializer(0),
                                      trainable=False)

        logits = model.model(input_images,
                             FLAGS.num_classes,
                             is_training=False)
        pred = tf.argmax(logits, axis=3)

        variable_averages = tf.train.ExponentialMovingAverage(
            0.997, global_step)
        saver = tf.train.Saver(variable_averages.variables_to_restore())

        with tf.Session(config=tf.ConfigProto(
                allow_soft_placement=True)) as sess:
            ckpt_state = tf.train.get_checkpoint_state(FLAGS.checkpoint_path)
            model_path = os.path.join(
                FLAGS.checkpoint_path,
                os.path.basename(ckpt_state.model_checkpoint_path))
            print('Restore from {}'.format(model_path))
            saver.restore(sess, model_path)

            im_fn_list = get_images()
            for im_fn in im_fn_list:
                im = cv2.imread(im_fn)[:, :, ::-1]
                im_resized, (ratio_h, ratio_w) = resize_image(im, size=32)
                # im_resized = im

                start = time.time()
                pred_re = sess.run([pred],
                                   feed_dict={input_images: [im_resized]})
                pred_re = np.array(np.squeeze(pred_re))
                cv2.imwrite(
                    os.path.join(FLAGS.result_path, os.path.basename(im_fn)),
                    pred_re)

                # img = visualize_segmentation_adaptive(pred_re, pascal_voc_lut)
                _diff_time = time.time() - start
                # cv2.imwrite(os.path.join(FLAGS.result_path, os.path.basename(im_fn)), img)

                print('{}: cost {:.0f}ms'.format(im_fn, _diff_time * 1000))
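
The resize_image helper used above is external to this snippet. A minimal sketch, assuming it only rescales each side to a multiple of `size` (the network downsamples by 32, which Example #3 below makes explicit with its round-to-multiple-of-32 logic) and returns the scale ratios; the implementation details here are hypothetical:

import cv2

def resize_image(im, size=32):
    # Round each side to the nearest multiple of `size` and rescale.
    h, w = im.shape[:2]
    new_h = max(size, int(round(h / float(size))) * size)
    new_w = max(size, int(round(w / float(size))) * size)
    im_resized = cv2.resize(im, (new_w, new_h))
    return im_resized, (new_h / float(h), new_w / float(w))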
Example #2
def main(argv=None):
    if os.path.exists(FLAGS.result_path):
        shutil.rmtree(FLAGS.result_path)
    os.makedirs(FLAGS.result_path)

    pascal_voc_lut = pascal_segmentation_lut()

    with tf.get_default_graph().as_default():
        input_images = tf.placeholder(tf.float32,
                                      shape=[None, None, None, 3],
                                      name='input_images')
        global_step = tf.get_variable('global_step', [],
                                      initializer=tf.constant_initializer(0),
                                      trainable=False)

        logits = model.model(input_images, is_training=False)
        pred = tf.argmax(logits, axis=3)

        variable_averages = tf.train.ExponentialMovingAverage(
            0.997, global_step)
        saver = tf.train.Saver(variable_averages.variables_to_restore())

        with tf.Session(config=tf.ConfigProto(
                allow_soft_placement=True)) as sess:
            # Load the model checkpoint
            ckpt_state = tf.train.get_checkpoint_state(FLAGS.checkpoint_path)
            model_path = os.path.join(
                FLAGS.checkpoint_path,
                os.path.basename(ckpt_state.model_checkpoint_path))
            print('Restore from {}'.format(model_path))
            saver.restore(sess, model_path)

            image_name_list = get_images()
            for image_name in image_name_list:
                im = np.asarray(Image.open(image_name))[:, :, 0:3][:, :, ::-1]
                print(im.shape)
                im_resized, (ratio_h, ratio_w) = resize_image(im, size=32)

                # Predict
                start = time.time()
                pred_re = sess.run(pred,
                                   feed_dict={input_images: [im_resized]})

                # Save the visualization
                img = visualize_segmentation_adaptive(pred_re[0],
                                                      pascal_voc_lut)
                Image.fromarray(img).convert("RGB").save(
                    os.path.join(FLAGS.result_path,
                                 os.path.basename(image_name)))
                print('{}: cost {:.0f}ms'.format(image_name,
                                                 time.time() - start))
    pass
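
visualize_segmentation_adaptive is likewise external. A rough sketch of what it is assumed to do: paint every class id present in the prediction with a stable color (the real helper presumably also uses the LUT to label a legend, omitted here):

import numpy as np

def visualize_segmentation_adaptive(pred, lut):
    # Hypothetical sketch: derive one fixed color per class id present.
    out = np.zeros(pred.shape + (3,), dtype=np.uint8)
    for cls in np.unique(pred):
        out[pred == cls] = np.random.RandomState(int(cls)).randint(0, 256, 3)
    return out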
Example #3
def main(argv=None):
    os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu_list
    pascal_voc_lut = pascal_segmentation_lut()

    filename_queue = tf.train.string_input_producer([FLAGS.test_data_path],
                                                    num_epochs=1)
    image, annotation = read_tfrecord_and_decode_into_image_annotation_pair_tensors(
        filename_queue)

    image_batch_tensor = tf.expand_dims(image, axis=0)
    annotation_batch_tensor = tf.expand_dims(annotation, axis=0)

    input_image_shape = tf.shape(image_batch_tensor)
    image_height_width = input_image_shape[1:3]
    image_height_width_float = tf.to_float(image_height_width)
    image_height_width_multiple = tf.to_int32(
        tf.round(image_height_width_float / 32) * 32)

    image_batch_tensor = tf.image.resize_images(image_batch_tensor,
                                                image_height_width_multiple)

    global_step = tf.get_variable('global_step', [],
                                  initializer=tf.constant_initializer(0),
                                  trainable=False)
    logits = model.model(image_batch_tensor, is_training=False)
    pred = tf.argmax(logits, axis=3)
    pred = tf.expand_dims(pred, 3)
    pred = tf.image.resize_nearest_neighbor(images=pred,
                                            size=image_height_width)
    annotation_batch_tensor = tf.image.resize_nearest_neighbor(
        images=annotation_batch_tensor, size=image_height_width)

    pred = tf.reshape(pred, [-1])
    gt = tf.reshape(annotation_batch_tensor, [-1])
    temp = tf.less_equal(gt, FLAGS.num_classes - 1)
    weights = tf.cast(temp, tf.int32)
    gt = tf.where(temp, gt, tf.cast(temp, tf.uint8))
    acc, acc_update_op = tf.contrib.metrics.streaming_accuracy(pred,
                                                               gt,
                                                               weights=weights)
    miou, miou_update_op = tf.contrib.metrics.streaming_mean_iou(
        pred, gt, num_classes=FLAGS.num_classes, weights=weights)

    with tf.get_default_graph().as_default():
        global_vars_init_op = tf.global_variables_initializer()
        local_vars_init_op = tf.local_variables_initializer()
        init = tf.group(local_vars_init_op, global_vars_init_op)

        variable_averages = tf.train.ExponentialMovingAverage(
            0.997, global_step)
        saver = tf.train.Saver(variable_averages.variables_to_restore())

        with tf.Session(config=tf.ConfigProto(
                allow_soft_placement=True)) as sess:
            sess.run(init)
            ckpt_state = tf.train.get_checkpoint_state(FLAGS.checkpoint_path)
            model_path = os.path.join(
                FLAGS.checkpoint_path,
                os.path.basename(ckpt_state.model_checkpoint_path))
            print('Restore from {}'.format(model_path))
            saver.restore(sess, model_path)

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord)

            for i in range(1449):
                start = time.time()
                image_np, annotation_np, pred_np, tmp_acc, tmp_miou = sess.run(
                    [image, annotation, pred, acc_update_op, miou_update_op])
                _diff_time = time.time() - start
                print('{}: cost {:.0f}ms'.format(i, _diff_time * 1000))
                #upsampled_predictions = pred_np.squeeze()
                #plt.imshow(image_np)
                #plt.show()
                #visualize_segmentation_adaptive(upsampled_predictions, pascal_voc_lut)
            acc_res = sess.run(acc)
            miou_res = sess.run(miou)
            print("Pascal VOC 2012 validation dataset pixel accuracy: " +
                  str(acc_res))
            print("Pascal VOC 2012 validation dataset Mean IoU: " +
                  str(miou_res))

            coord.request_stop()
            coord.join(threads)
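
Example #3 relies on TensorFlow's streaming metrics: each sess.run of an update op folds one batch into internal accumulator variables (created as local variables, hence local_vars_init_op), and the final value is read from the metric tensor afterwards. A self-contained sketch of the same pattern with tf.metrics.accuracy, the non-contrib successor of streaming_accuracy (note the swapped argument order):

import tensorflow as tf

preds = tf.placeholder(tf.int64, [None])
labels = tf.placeholder(tf.int64, [None])
acc, acc_update = tf.metrics.accuracy(labels, preds)

with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())  # metric accumulators are local variables
    for p, l in [([0, 1], [0, 1]), ([1, 1], [0, 1])]:
        sess.run(acc_update, feed_dict={preds: p, labels: l})
    print(sess.run(acc))  # 0.75, accumulated over both batches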
Example #4
def main(argv=None):
    # Class-id -> class-name dictionary
    pascal_voc_lut = pascal_segmentation_lut()
    # Class labels (the dictionary keys)
    class_labels = list(pascal_voc_lut.keys())
    # Mapping from class id to color
    with open('data/color_map.pkl', 'rb') as f:
        color_map = pickle.load(f)

    # Log directory
    style_time = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    os.makedirs(FLAGS.logs_path + style_time)

    # Checkpoint directory
    if not os.path.exists(FLAGS.checkpoint_path):
        os.makedirs(FLAGS.checkpoint_path)

    # Filename queue
    filename_queue = tf.train.string_input_producer([FLAGS.training_data_path],
                                                    num_epochs=1000)
    # Decode the TFRecord data
    image, annotation = read_tfrecord_and_decode_into_image_annotation_pair_tensors(
        filename_queue)
    # Random horizontal flip
    image, annotation = flip_randomly_left_right_image_with_annotation(
        image, annotation)
    # Random color distortion
    image = distort_randomly_image_color(image)
    # Random scaling
    image_train_size = [FLAGS.train_size, FLAGS.train_size]
    resize_image, resize_annotation = scale_randomly_image_with_annotation_with_fixed_size_output(
        image, annotation, image_train_size)
    # The annotation gained an extra dimension when read; squeeze it back out
    resize_annotation = tf.squeeze(resize_annotation)
    # Batch the samples
    image_batch, annotation_batch = tf.train.shuffle_batch(
        [resize_image, resize_annotation],
        batch_size=FLAGS.batch_size,
        capacity=1000,
        num_threads=4,
        min_after_dequeue=500)

    # Learning rate and global step
    learning_rate = tf.Variable(FLAGS.learning_rate, trainable=False)
    global_step = tf.get_variable('global_step', [],
                                  initializer=tf.constant_initializer(0),
                                  trainable=False)

    # Compute the losses and predictions
    total_loss, model_loss, output_pred = tower_loss(image_batch,
                                                     annotation_batch,
                                                     class_labels)

    # 1. Optimize the loss: gradient updates
    gradient_updates_op = tf.train.AdamOptimizer(learning_rate).minimize(
        total_loss, global_step=global_step)

    # 2. Moving-average updates
    variable_averages = tf.train.ExponentialMovingAverage(
        FLAGS.moving_average_decay, global_step)
    variables_averages_op = variable_averages.apply(tf.trainable_variables())

    # 3. Batch-norm statistics updates
    batch_norm_updates_op = tf.group(
        *tf.get_collection(tf.GraphKeys.UPDATE_OPS))

    # 4. Combine everything into a single train op
    with tf.control_dependencies(
        [variables_averages_op, gradient_updates_op, batch_norm_updates_op]):
        train_op = tf.no_op(name='train_op')

    saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)

    summary_writer = tf.summary.FileWriter(FLAGS.logs_path + style_time,
                                           tf.get_default_graph())
    # Summary: learning rate
    tf.summary.scalar('learning_rate', learning_rate)
    # Summary: images (tf.summary.image needs a static string tag, so the
    # fed name placeholder is carried along but not used as the tag)
    log_image_data = tf.placeholder(tf.uint8, [None, None, 3])
    log_image_name = tf.placeholder(tf.string)
    log_image = tf.summary.image('log_image',
                                 tf.expand_dims(log_image_data, 0))
    # Merge all summaries
    summary_op = tf.summary.merge_all()

    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        sess.run([
            tf.global_variables_initializer(),
            tf.local_variables_initializer()
        ])
        # Restore the model
        restore_step = 0
        if FLAGS.restore:
            print('continue training from previous checkpoint')
            ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_path)
            restore_step = int(ckpt.split('.')[0].split('_')[-1])
            saver.restore(sess, ckpt)
        elif FLAGS.pre_train_model_path is not None:
            # Load a pre-trained model
            # Returns a function that assigns specific variables from a checkpoint.
            variable_restore_op = slim.assign_from_checkpoint_fn(
                FLAGS.pre_train_model_path,
                slim.get_trainable_variables(),
                ignore_missing_vars=True)
            variable_restore_op(sess)
            pass

        start = time.time()
        coord = tf.train.Coordinator()
        # Start the queue runners
        threads = tf.train.start_queue_runners(coord=coord)
        try:
            while not coord.should_stop():
                for step in range(restore_step, FLAGS.max_steps):
                    # Decay the learning rate
                    if step != 0 and step % FLAGS.decay_steps == 0:
                        sess.run(
                            tf.assign(learning_rate,
                                      learning_rate.eval() * FLAGS.decay_rate))
                        pass

                    # Run the losses and the train op
                    ml, tl, _ = sess.run([model_loss, total_loss, train_op])

                    # Loss diverged (training is not converging)
                    if np.isnan(tl):
                        print('Loss diverged, stop training')
                        break

                    # Time the steps and print progress
                    if step % 10 == 0:
                        print(
                            'Step {:06d}, model loss {:.4f}, total loss {:.4f}, {:.3f} seconds/step, lr: {:.7f}'
                            .format(step, ml, tl, (time.time() - start) / 10,
                                    learning_rate.eval()))
                        start = time.time()
                        pass

                    # Save the model
                    if (step + 1) % FLAGS.save_checkpoint_steps == 0:
                        filename = ('RefineNet' +
                                    '_step_{:d}'.format(step + 1) + '.ckpt')
                        filename = os.path.join(FLAGS.checkpoint_path,
                                                filename)
                        saver.save(sess, filename)
                        print('Write model to: {:s}'.format(filename))

                    # Save summaries
                    if step % FLAGS.save_summary_steps == 0:
                        # Run once more to fetch images and predictions
                        img_split, seg_split, pred = sess.run(
                            [image_batch, annotation_batch, output_pred])

                        # Squeeze and take the first sample of the batch
                        img_split = np.squeeze(img_split)[0]
                        seg_split = np.squeeze(seg_split)[0]
                        pred_split = np.squeeze(pred)[0]

                        # Colorize the annotation
                        color_seg = np.zeros(
                            (seg_split.shape[0], seg_split.shape[1], 3))
                        for i in range(seg_split.shape[0]):
                            for j in range(seg_split.shape[1]):
                                color_seg[i, j, :] = color_map[str(
                                    seg_split[i][j])]

                        # Colorize the prediction
                        color_pred = np.zeros(
                            (pred_split.shape[0], pred_split.shape[1], 3))
                        for i in range(pred_split.shape[0]):
                            for j in range(pred_split.shape[1]):
                                color_pred[i, j, :] = color_map[str(
                                    pred_split[i][j])]

                        write_img = np.hstack(
                            (img_split, color_seg, color_pred))
                        _, summary_str, log_image_str = sess.run(
                            [train_op, summary_op, log_image],
                            feed_dict={
                                log_image_name: ('%06d' % step),
                                log_image_data: write_img
                            })
                        # Write the summaries (scalar and image)
                        summary_writer.add_summary(summary_str,
                                                   global_step=step)
                        summary_writer.add_summary(log_image_str,
                                                   global_step=step)

                    pass
                pass
            pass
        except tf.errors.OutOfRangeError:
            print('finish')
        finally:
            coord.request_stop()
        coord.join(threads)
    pass
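
The nested loops that colorize seg_split and pred_split run one Python iteration per pixel; the same id-to-color mapping can be done with a single vectorized numpy lookup. A sketch, assuming color_map maps stringified class ids to RGB triples as above:

import numpy as np

def colorize(label_map, color_map):
    # Build a (num_ids, 3) palette once, then index it with the whole label map.
    palette = np.zeros((max(int(k) for k in color_map) + 1, 3), dtype=np.uint8)
    for k, rgb in color_map.items():
        palette[int(k)] = rgb
    return palette[label_map.astype(np.int32)]  # (H, W) -> (H, W, 3)

# Usage: color_seg = colorize(seg_split, color_map); color_pred = colorize(pred_split, color_map)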
Example #5
def main(argv=None):
    gpus = range(len(FLAGS.gpu_list.split(',')))
    pascal_voc_lut = pascal_segmentation_lut()
    class_labels = list(pascal_voc_lut.keys())
    with open('data/color_map', 'rb') as f:
        color_map = pickle.load(f, encoding="bytes")

    os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu_list
    now = datetime.datetime.now()
    StyleTime = now.strftime("%Y-%m-%d-%H-%M-%S")
    os.makedirs(FLAGS.logs_path + StyleTime)
    if not os.path.exists(FLAGS.checkpoint_path):
        os.makedirs(FLAGS.checkpoint_path)
    elif not FLAGS.restore:
        shutil.rmtree(FLAGS.checkpoint_path)
        os.makedirs(FLAGS.checkpoint_path)

    filename_queue = tf.train.string_input_producer([FLAGS.training_data_path],
                                                    num_epochs=1000)
    image, annotation = read_tfrecord_and_decode_into_image_annotation_pair_tensors(
        filename_queue)

    image, annotation = flip_randomly_left_right_image_with_annotation(
        image, annotation)
    image = distort_randomly_image_color(image)

    image_train_size = [FLAGS.train_size, FLAGS.train_size]
    resized_image, resized_annotation = scale_randomly_image_with_annotation_with_fixed_size_output(
        image, annotation, image_train_size)
    resized_annotation = tf.squeeze(resized_annotation)

    image_batch, annotation_batch = tf.train.shuffle_batch(
        [resized_image, resized_annotation],
        batch_size=FLAGS.batch_size * len(gpus),
        capacity=1000,
        num_threads=4,
        min_after_dequeue=500)

    # split
    input_images_split = tf.split(image_batch, len(gpus))
    input_segs_split = tf.split(annotation_batch, len(gpus))

    learning_rate = tf.Variable(FLAGS.learning_rate, trainable=False)
    global_step = tf.get_variable('global_step', [],
                                  initializer=tf.constant_initializer(0),
                                  trainable=False)
    # add summary
    tf.summary.scalar('learning_rate', learning_rate)
    opt = tf.train.AdamOptimizer(learning_rate)

    tower_grads = []
    reuse_variables = None
    for i, gpu_id in enumerate(gpus):
        with tf.device('/gpu:%d' % gpu_id):
            with tf.name_scope('model_%d' % gpu_id) as scope:
                iis = input_images_split[i]
                isms = input_segs_split[i]
                total_loss, model_loss, output_pred = tower_loss(
                    iis, isms, class_labels, reuse_variables)
                batch_norm_updates_op = tf.group(
                    *tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope))
                reuse_variables = True
                grads = opt.compute_gradients(total_loss)
                tower_grads.append(grads)

    grads = average_gradients(tower_grads)
    apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)

    summary_op = tf.summary.merge_all()
    log_image, log_image_data, log_image_name = build_image_summary()
    # save moving average
    variable_averages = tf.train.ExponentialMovingAverage(
        FLAGS.moving_average_decay, global_step)
    variables_averages_op = variable_averages.apply(tf.trainable_variables())
    # batch norm updates
    with tf.control_dependencies(
        [variables_averages_op, apply_gradient_op, batch_norm_updates_op]):
        train_op = tf.no_op(name='train_op')

    saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)
    summary_writer = tf.summary.FileWriter(FLAGS.logs_path + StyleTime,
                                           tf.get_default_graph())

    # if FLAGS.pretrained_model_path is not None:
    #     variable_restore_op = slim.assign_from_checkpoint_fn(FLAGS.pretrained_model_path,
    #                                                          slim.get_trainable_variables(),
    #                                                          ignore_missing_vars=True)

    global_vars_init_op = tf.global_variables_initializer()
    local_vars_init_op = tf.local_variables_initializer()
    init = tf.group(local_vars_init_op, global_vars_init_op)

    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        restore_step = 0
        if FLAGS.restore:
            sess.run(init)
            print('continue training from previous checkpoint')
            ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_path)
            restore_step = int(ckpt.split('.')[0].split('_')[-1])
            saver.restore(sess, ckpt)
        else:
            sess.run(init)
            # if FLAGS.pretrained_model_path is not None:
            #     variable_restore_op(sess)

        start = time.time()
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        try:
            while not coord.should_stop():
                for step in range(restore_step, FLAGS.max_steps):
                    if step != 0 and step % FLAGS.decay_steps == 0:
                        sess.run(
                            tf.assign(learning_rate,
                                      learning_rate.eval() * FLAGS.decay_rate))

                    ml, tl, _ = sess.run([model_loss, total_loss, train_op])
                    if np.isnan(tl):
                        print('Loss diverged, stop training')
                        break
                    if step % 10 == 0:
                        avg_time_per_step = (time.time() - start) / 10
                        start = time.time()
                        print(
                            'Step {:06d}, model loss {:.4f}, total loss {:.4f}, {:.3f} seconds/step, lr: {:.7f}'
                            .format(step, ml, tl, avg_time_per_step,
                                    learning_rate.eval()))

                    if (step + 1) % FLAGS.save_checkpoint_steps == 0:
                        filename = ('RefineNet' +
                                    '_step_{:d}'.format(step + 1) + '.ckpt')
                        filename = os.path.join(FLAGS.checkpoint_path,
                                                filename)
                        saver.save(sess, filename)
                        print('Write model to: {:s}'.format(filename))

                    if step % FLAGS.save_summary_steps == 0:
                        _, tl, summary_str = sess.run(
                            [train_op, total_loss, summary_op])
                        summary_writer.add_summary(summary_str,
                                                   global_step=step)

                    if step % FLAGS.save_image_steps == 0:
                        log_image_name_str = ('%06d' % step)
                        img_split, seg_split, pred = sess.run(
                            [iis, isms, output_pred])

                        img_split = np.squeeze(img_split)[0]
                        seg_split = np.squeeze(seg_split)[0]
                        pred = np.squeeze(pred)[0]

                        #img_split=cv2.resize(img_split,(128,128))

                        color_seg = np.zeros(
                            (seg_split.shape[0], seg_split.shape[1], 3))
                        for i in range(seg_split.shape[0]):
                            for j in range(seg_split.shape[1]):
                                color_seg[i, j, :] = color_map[str(
                                    seg_split[i][j])]

                        color_pred = np.zeros(
                            (pred.shape[0], pred.shape[1], 3))
                        for i in range(pred.shape[0]):
                            for j in range(pred.shape[1]):
                                color_pred[i, j, :] = color_map[str(
                                    pred[i][j])]

                        write_img = np.hstack((color_seg, color_pred))
                        log_image_summary_op = sess.run(
                            log_image,
                            feed_dict={
                                log_image_name: log_image_name_str,
                                log_image_data: write_img
                            })
                        summary_writer.add_summary(log_image_summary_op,
                                                   global_step=step)
        except tf.errors.OutOfRangeError:
            print('finish')
        finally:
            coord.request_stop()
        coord.join(threads)
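
Example #5 assumes two helpers that are not shown. average_gradients presumably follows the classic TensorFlow multi-GPU tower pattern (average each variable's gradient across towers), and build_image_summary is sketched here under the assumption that it returns an image summary op plus the two placeholders fed in the loop above; note tf.summary.image needs a static Python string tag, so the name placeholder is only carried along to match the call site:

import tensorflow as tf

def average_gradients(tower_grads):
    # tower_grads: one [(grad, var), ...] list per tower; returns one averaged list.
    average_grads = []
    for grad_and_vars in zip(*tower_grads):
        grads = [tf.expand_dims(g, 0) for g, _ in grad_and_vars]
        grad = tf.reduce_mean(tf.concat(grads, 0), 0)
        average_grads.append((grad, grad_and_vars[0][1]))
    return average_grads

def build_image_summary():
    log_image_data = tf.placeholder(tf.uint8, [None, None, 3])
    log_image_name = tf.placeholder(tf.string)  # fed but not usable as the tag
    log_image = tf.summary.image('log_image', tf.expand_dims(log_image_data, 0))
    return log_image, log_image_data, log_image_name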
Example #6
import tensorflow as tf

# Assumed import path for the TFRecord reader used below
from utils.tf_records import read_tfrecord_and_decode_into_image_annotation_pair_tensors
from models.fcn_8s import FCN_8s, extract_vgg_16_mapping_without_fc8
from utils.pascal_voc import pascal_segmentation_lut
from utils.training import get_valid_logits_and_labels
from utils.augmentation import (
    distort_randomly_image_color,
    flip_randomly_left_right_image_with_annotation,
    scale_randomly_image_with_annotation_with_fixed_size_output)

slim = tf.contrib.slim
vgg_checkpoint_path = '/home/jochiu/Fully_CNN/vgg_16.ckpt'
log_folder = '/home/jochiu/Fully_CNN/log'

image_train_size = [384, 384]
number_of_classes = 21
tfrecord_filename = 'pascal_augmented_train.tfrecords'
pascal_voc_lut = pascal_segmentation_lut()
class_labels = list(pascal_voc_lut.keys())

filename_queue = tf.train.string_input_producer([tfrecord_filename],
                                                num_epochs=10)

image, annotation = read_tfrecord_and_decode_into_image_annotation_pair_tensors(
    filename_queue)

image, annotation = flip_randomly_left_right_image_with_annotation(
    image, annotation)

resized_image, resized_annotation = scale_randomly_image_with_annotation_with_fixed_size_output(
    image, annotation, image_train_size)

resized_annotation = tf.squeeze(resized_annotation)
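
tower_loss, called by Examples #4, #5 and #7 but never defined there, presumably builds the per-tower loss from the helpers imported in this snippet. A minimal sketch under those assumptions (the signatures of model.model and get_valid_logits_and_labels are taken from how the other examples call them, not from the actual source):

def tower_loss(images, annotations, class_labels, reuse_variables=None):
    with tf.variable_scope(tf.get_variable_scope(), reuse=reuse_variables):
        logits = model.model(images, is_training=True)

    # Keep only pixels whose label is a valid class (masks e.g. the 255 'void' label).
    valid_labels, valid_logits = get_valid_logits_and_labels(
        annotation_batch_tensor=annotations,
        logits_batch_tensor=logits,
        class_labels=class_labels)

    model_loss = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(labels=valid_labels,
                                                logits=valid_logits))
    # Assumes the model registers its weight-decay terms in REGULARIZATION_LOSSES.
    total_loss = model_loss + tf.add_n(
        tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
    output_pred = tf.argmax(logits, axis=3)
    return total_loss, model_loss, output_pred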
Example #7
def main(argv=None):
    gpus = range(len(FLAGS.gpu_list.split(',')))
    pascal_voc_lut = pascal_segmentation_lut()
    class_labels = list(pascal_voc_lut.keys())

    os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu_list
    now = datetime.datetime.now()
    StyleTime = now.strftime("%Y-%m-%d-%H-%M-%S")
    os.makedirs(FLAGS.logs_path + StyleTime)
    if not os.path.exists(FLAGS.checkpoint_path):
        os.makedirs(FLAGS.checkpoint_path)
    elif not FLAGS.restore:
        shutil.rmtree(FLAGS.checkpoint_path)
        os.makedirs(FLAGS.checkpoint_path)

    #input_images = tf.placeholder(tf.float32, shape=[None, None, None, 3], name='input_images')
    #input_segs = tf.placeholder(tf.float32, shape=[None, None,None, 1], name='input_segs')

    filename_queue = tf.train.string_input_producer([FLAGS.training_data_path],
                                                    num_epochs=1000)
    image, annotation = read_tfrecord_and_decode_into_image_annotation_pair_tensors(
        filename_queue)

    image, annotation = flip_randomly_left_right_image_with_annotation(
        image, annotation)
    image = distort_randomly_image_color(image)

    image_train_size = [FLAGS.train_size, FLAGS.train_size]
    resized_image, resized_annotation = scale_randomly_image_with_annotation_with_fixed_size_output(
        image, annotation, image_train_size)
    resized_annotation = tf.squeeze(resized_annotation)

    image_batch, annotation_batch = tf.train.shuffle_batch(
        [resized_image, resized_annotation],
        batch_size=FLAGS.batch_size * len(gpus),
        capacity=1000,
        num_threads=4,
        min_after_dequeue=500)

    # split
    input_images_split = tf.split(image_batch, len(gpus))
    input_segs_split = tf.split(annotation_batch, len(gpus))

    # Define the loss, learning rate, moving-average op, and training procedure.
    learning_rate = tf.Variable(FLAGS.learning_rate, trainable=False)
    global_step = tf.get_variable('global_step', [],
                                  initializer=tf.constant_initializer(0),
                                  trainable=False)
    # add summary
    tf.summary.scalar('learning_rate', learning_rate)
    opt = tf.train.AdamOptimizer(learning_rate)

    tower_grads = []
    reuse_variables = None
    for i, gpu_id in enumerate(gpus):
        with tf.device('/gpu:%d' % gpu_id):
            with tf.name_scope('model_%d' % gpu_id):
                iis = input_images_split[i]
                isms = input_segs_split[i]
                total_loss, model_loss, output_pred = tower_loss(
                    iis, isms, class_labels, reuse_variables)
                reuse_variables = True
                grads = opt.compute_gradients(total_loss)
                tower_grads.append(grads)

    grads = average_gradients(tower_grads)
    apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)

    summary_op = tf.summary.merge_all()
    # save moving average
    variable_averages = tf.train.ExponentialMovingAverage(
        FLAGS.moving_average_decay, global_step)
    variables_averages_op = variable_averages.apply(tf.trainable_variables())
    # batch norm updates
    with tf.control_dependencies([variables_averages_op, apply_gradient_op]):
        train_op = tf.no_op(name='train_op')

    saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)
    summary_writer = tf.summary.FileWriter(FLAGS.logs_path + StyleTime,
                                           tf.get_default_graph())

    if FLAGS.pretrained_model_path is not None:
        variable_restore_op = slim.assign_from_checkpoint_fn(
            FLAGS.pretrained_model_path,
            slim.get_trainable_variables(),
            ignore_missing_vars=True)

    global_vars_init_op = tf.global_variables_initializer()
    local_vars_init_op = tf.local_variables_initializer()
    init = tf.group(local_vars_init_op, global_vars_init_op)

    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        restore_step = 0
        if FLAGS.restore:
            sess.run(init)
            print('continue training from previous checkpoint')
            ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_path)
            restore_step = int(ckpt.split('.')[0].split('_')[-1])
            saver.restore(sess, ckpt)
        else:
            sess.run(init)
            if FLAGS.pretrained_model_path is not None:
                variable_restore_op(sess)

        start = time.time()
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        try:
            while not coord.should_stop():
                for step in range(restore_step, FLAGS.max_steps):

                    if step != 0 and step % FLAGS.decay_steps == 0:
                        sess.run(
                            tf.assign(learning_rate,
                                      learning_rate.eval() * FLAGS.decay_rate))

                    ml, tl, _ = sess.run([model_loss, total_loss, train_op])
                    if np.isnan(tl):
                        print('Loss diverged, stop training')
                        break
                    if step % 10 == 0:
                        avg_time_per_step = (time.time() - start) / 10
                        start = time.time()
                        print(
                            'Step {:06d}, model loss {:.6f}, total loss {:.6f}, {:.3f} seconds/step, lr: {:.10f}'
                            .format(step, ml, tl, avg_time_per_step,
                                    learning_rate.eval()))

                    if (step + 1) % FLAGS.save_checkpoint_steps == 0:
                        filename = ('RefineNet' +
                                    '_step_{:d}'.format(step + 1) + '.ckpt')
                        filename = os.path.join(FLAGS.checkpoint_path,
                                                filename)
                        saver.save(sess, filename)
                        print('Write model to: {:s}'.format(filename))

                    if step % FLAGS.save_summary_steps == 0:
                        _, tl, summary_str = sess.run(
                            [train_op, total_loss, summary_op])
                        summary_writer.add_summary(summary_str,
                                                   global_step=step)

        except tf.errors.OutOfRangeError:
            print('finish')
        finally:
            coord.request_stop()
        coord.join(threads)