示例#1
0
def main():
    """Create the model and start the training.

    Reads the run configuration from ``get_arguments()``, builds the input
    queue, the ``DeepLabLFOVModel`` network and an Adam optimiser over all
    trainable variables, then runs ``args.num_steps`` optimisation steps.

    Every ``args.save_pred_every`` steps (step 0 included) a figure showing
    the input, the label mask and the prediction for the first
    ``args.save_num_images`` batch elements is written into ``args.save_dir``
    and a checkpoint is stored in ``args.snapshot_dir``.  The merged
    summaries are written to ``args.snapshot_dir`` on every step.

    The queue threads are stopped and the summary writer is closed even when
    a training step raises.
    """
    args = get_arguments()

    # ``args.input_size`` is a "height,width" string.
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    # Create queue coordinator.
    coord = tf.train.Coordinator()

    # Load reader.
    with tf.name_scope("create_inputs"):
        reader = ImageReader(args.data_dir, args.data_list, input_size,
                             RANDOM_SCALE, coord)
        image_batch, label_batch = reader.dequeue(args.batch_size)

    # Create network.
    net = DeepLabLFOVModel(args.weights_path)

    # Define the loss and optimisation parameters.
    loss = net.loss(image_batch, label_batch)
    optimiser = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
    trainable = tf.trainable_variables()
    optim = optimiser.minimize(loss, var_list=trainable)

    tf.summary.scalar("lfov_loss", loss)

    pred = net.preds(image_batch)

    # Set up tf session and initialize variables.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    # tf.initialize_all_variables() is deprecated in favour of
    # tf.global_variables_initializer().
    sess.run(tf.global_variables_initializer())

    # Saver for storing checkpoints of the model (trainable variables only).
    saver = tf.train.Saver(var_list=trainable, max_to_keep=40)
    if args.restore_from is not None:
        load(saver, sess, args.restore_from)

    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    summary_merged = tf.summary.merge_all()
    summary_writer = tf.summary.FileWriter(args.snapshot_dir,
                                           graph=tf.get_default_graph())

    # Start queue threads last so that nothing between here and the
    # try-block can fail while they are already running.
    threads = tf.train.start_queue_runners(coord=coord, sess=sess)

    try:
        # Iterate over training steps.
        for step in range(args.num_steps):
            start_time = time.time()

            if step % args.save_pred_every == 0:
                loss_value, summary, images, labels, preds, _ = sess.run(
                    [loss, summary_merged, image_batch, label_batch, pred,
                     optim])
                fig, axes = plt.subplots(args.save_num_images, 3,
                                         figsize=(16, 12))
                for i in range(args.save_num_images):
                    # Add the mean back and reverse the channel axis
                    # (presumably BGR -> RGB; confirm against ImageReader).
                    axes.flat[i * 3].set_title('data')
                    axes.flat[i * 3].imshow(
                        (images[i] + IMG_MEAN)[:, :, ::-1].astype(np.uint8))

                    axes.flat[i * 3 + 1].set_title('mask')
                    axes.flat[i * 3 + 1].imshow(
                        decode_labels(labels[i, :, :, 0]))

                    axes.flat[i * 3 + 2].set_title('pred')
                    axes.flat[i * 3 + 2].imshow(
                        decode_labels(preds[i, :, :, 0]))
                # os.path.join keeps the file inside save_dir even when the
                # directory was given without a trailing separator.
                fig.savefig(
                    os.path.join(args.save_dir, str(start_time) + ".png"))
                plt.close(fig)
                save(saver, sess, args.snapshot_dir, step)
            else:
                loss_value, summary, _ = sess.run(
                    [loss, summary_merged, optim])
            summary_writer.add_summary(summary, step)

            duration = time.time() - start_time
            print('step {:d} \t loss = {:.3f}, ({:.3f} sec/step)'.format(
                step, loss_value, duration))
    finally:
        coord.request_stop()
        summary_writer.close()
        coord.join(threads)
示例#2
0
def main():
    """Create the model and start the training.

    Builds the image/label input queue, the ``DeepLabLFOVModel`` network, an
    Adam optimiser over all trainable variables and a ``lfov_loss`` scalar
    summary, then performs ``args.num_steps`` training steps.

    On every step whose index is a multiple of ``args.save_pred_every`` the
    first ``args.save_num_images`` samples of the batch are plotted (input,
    label mask, prediction) into a PNG inside ``args.save_dir`` and a
    checkpoint is written to ``args.snapshot_dir``.  Merged summaries are
    logged to ``args.snapshot_dir`` on every step.

    Queue threads are always stopped and the summary writer is always closed,
    including when a step raises.
    """
    args = get_arguments()

    # ``args.input_size`` is a "height,width" string.
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    # Create queue coordinator.
    coord = tf.train.Coordinator()

    # Load reader.
    with tf.name_scope("create_inputs"):
        reader = ImageReader(
            args.data_dir,
            args.data_list,
            input_size,
            RANDOM_SCALE,
            coord)
        image_batch, label_batch = reader.dequeue(args.batch_size)

    # Create network.
    net = DeepLabLFOVModel(args.weights_path)

    # Define the loss and optimisation parameters.
    loss = net.loss(image_batch, label_batch)
    optimiser = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
    trainable = tf.trainable_variables()
    optim = optimiser.minimize(loss, var_list=trainable)

    tf.summary.scalar("lfov_loss", loss)

    pred = net.preds(image_batch)

    # Set up tf session and initialize variables.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    # Replaces the deprecated tf.initialize_all_variables().
    init = tf.global_variables_initializer()
    sess.run(init)

    # Saver for storing checkpoints of the model (trainable variables only).
    saver = tf.train.Saver(var_list=trainable, max_to_keep=40)
    if args.restore_from is not None:
        load(saver, sess, args.restore_from)

    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    summary_merged = tf.summary.merge_all()
    summary_writer = tf.summary.FileWriter(
        args.snapshot_dir, graph=tf.get_default_graph())

    # Start queue threads only once all other set-up has succeeded.
    threads = tf.train.start_queue_runners(coord=coord, sess=sess)

    try:
        # Iterate over training steps.
        for step in range(args.num_steps):
            start_time = time.time()
            is_snapshot_step = step % args.save_pred_every == 0

            if is_snapshot_step:
                # Fetch the batch and predictions together with the training
                # op so that the plotted data is what was just trained on.
                loss_value, summary, images, labels, preds, _ = sess.run(
                    [loss, summary_merged, image_batch, label_batch, pred,
                     optim])
            else:
                loss_value, summary, _ = sess.run(
                    [loss, summary_merged, optim])
            summary_writer.add_summary(summary, step)

            if is_snapshot_step:
                fig, axes = plt.subplots(
                    args.save_num_images, 3, figsize=(16, 12))
                for i in range(args.save_num_images):
                    # Add the mean back and reverse the channel axis
                    # (presumably BGR -> RGB; confirm against ImageReader).
                    axes.flat[i * 3].set_title('data')
                    axes.flat[i * 3].imshow(
                        (images[i] + IMG_MEAN)[:, :, ::-1].astype(np.uint8))

                    axes.flat[i * 3 + 1].set_title('mask')
                    axes.flat[i * 3 + 1].imshow(
                        decode_labels(labels[i, :, :, 0]))

                    axes.flat[i * 3 + 2].set_title('pred')
                    axes.flat[i * 3 + 2].imshow(
                        decode_labels(preds[i, :, :, 0]))
                # Join instead of concatenating so a save_dir without a
                # trailing separator still receives the file.
                fig.savefig(
                    os.path.join(args.save_dir, str(start_time) + ".png"))
                plt.close(fig)
                save(saver, sess, args.snapshot_dir, step)

            duration = time.time() - start_time
            print('step {:d} \t loss = {:.3f}, ({:.3f} sec/step)'.format(
                step, loss_value, duration))
    finally:
        coord.request_stop()
        summary_writer.close()
        coord.join(threads)
示例#3
0
def main():
    """Create the model and start the evaluation process.

    Reads single images and labels through ``ImageReader`` (no resizing, no
    random scaling), runs ``DeepLabLFOVModel`` on them one at a time for
    ``args.num_steps`` steps and accumulates a streaming mean IoU over 21
    classes, which is printed at the end.

    When ``args.save_dir`` is not ``None`` the decoded prediction of every
    step is saved there as ``<step>.png``; the directory is created if it
    does not exist.  Queue threads are stopped even if evaluation raises.
    """
    args = get_arguments()

    # Create queue coordinator.
    coord = tf.train.Coordinator()

    # Load reader.
    with tf.name_scope("create_inputs"):
        reader = ImageReader(
            args.data_dir,
            args.data_list,
            input_size=None,
            random_scale=False,
            coord=coord)
        image, label = reader.image, reader.label
    # Add the batch dimension.
    image_batch = tf.expand_dims(image, dim=0)
    label_batch = tf.expand_dims(label, dim=0)

    # Create network.
    net = DeepLabLFOVModel(args.weights_path)

    # Which variables to load.
    trainable = tf.trainable_variables()

    # Predictions.
    pred = net.preds(image_batch)

    # mIoU
    mIoU, update_op = tf.contrib.metrics.streaming_mean_iou(
        pred, label_batch, num_classes=21)

    # Set up tf session and initialize variables.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)

    # The tf.initialize_*_variables() helpers are deprecated.  The local
    # initializer is run as well because the original code did so right
    # after creating the streaming metric.
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())

    # Load weights.
    saver = tf.train.Saver(var_list=trainable)
    if args.restore_from is not None:
        load(saver, sess, args.restore_from)

    # The save loop below treats save_dir=None as "do not save", so the
    # directory check must honour that too; os.path.exists(None) raises.
    save_predictions = args.save_dir is not None
    if save_predictions and not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    # Start queue threads.
    threads = tf.train.start_queue_runners(coord=coord, sess=sess)

    try:
        # Iterate over images.
        for step in range(args.num_steps):
            preds, _ = sess.run([pred, update_op])

            if save_predictions:
                img = decode_labels(preds[0, :, :, 0])
                im = Image.fromarray(img)
                im.save(os.path.join(args.save_dir, str(step) + '.png'))
            if step % 100 == 0:
                print('step {:d} \t'.format(step))
        print('Mean IoU: {:.3f}'.format(mIoU.eval(session=sess)))
    finally:
        coord.request_stop()
        coord.join(threads)
示例#4
0
def main():
    """Create the model and start the training.

    Builds a training and a validation input queue, selects between them
    with the boolean ``is_training`` placeholder, and trains
    ``DeepLabLFOVModel`` with Adam for ``args.num_steps`` steps while
    tracking recall, precision and a two-class mean IoU.

    Side effects on disk: ``args.log_dir``, ``args.save_dir`` and
    ``args.snapshot_dir`` are deleted if they already exist (``save_dir`` is
    then recreated).

    Every ``args.save_pred_every`` steps (step 0 included):
      * a figure with input / mask / prediction of the first
        ``args.save_num_images`` training samples is saved to
        ``args.save_dir``;
      * a checkpoint is written to ``args.snapshot_dir``;
      * ``VAL_LOOP`` validation batches are evaluated, their metrics are
        printed and one figure per validation image is saved as
        ``<start_time>_<batch>_<index>test.png``.

    Queue threads are stopped and the summary writer is closed even when a
    step raises.
    """
    args = get_arguments()

    # ``args.input_size`` is a "height,width" string.
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    # Create queue coordinator.
    coord = tf.train.Coordinator()

    # Load readers.
    with tf.name_scope("create_train_inputs"):
        reader_train = ImageReader(
            args.data_dir,
            args.data_train_list,
            input_size,
            RANDOM_SCALE,
            coord)
        image_batch_train, label_batch_train = reader_train.dequeue(
            args.batch_size)

    with tf.name_scope("create_val_inputs"):
        reader_val = ImageReader(
            args.data_dir,
            args.data_val_list,
            input_size,
            False,
            coord)
        image_batch_val, label_batch_val = reader_val.dequeue(
            args.batch_size)

    # The node name 'stauts' (sic) is kept unchanged because it is part of
    # the graph.
    is_training = tf.placeholder(tf.bool, shape=[], name='stauts')
    # NOTE(review): both dequeue tensors are created outside the branch
    # functions, so presumably both queues advance on every run regardless
    # of ``is_training`` - confirm whether that is acceptable.
    image_batch, label_batch = tf.cond(
        is_training,
        lambda: (image_batch_train, label_batch_train),
        lambda: (image_batch_val, label_batch_val))

    # Create network.
    net = DeepLabLFOVModel(args.weights_path)

    # Define the loss and optimisation parameters.
    global_step = tf.Variable(0, trainable=False)
    step_thre = tf.constant(args.step_thre)
    loss, recall, precision = net.loss(
        image_batch, label_batch, global_step, step_thre)
    optimizer = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
    trainable = tf.trainable_variables()
    optim = optimizer.minimize(
        loss, var_list=trainable, global_step=global_step)

    pred = net.preds(image_batch)

    # Two independent streaming mIoU metrics: one updated on training
    # steps, one on validation batches.
    mIoU, update_op = tf.metrics.mean_iou(label_batch, pred, num_classes=2)
    mIoU_vali, update_op_vali = tf.metrics.mean_iou(
        label_batch, pred, num_classes=2)

    # NOTE(review): no summary is registered in this function; this relies
    # on summaries being created elsewhere (presumably in net.loss/preds).
    merged = tf.summary.merge_all()
    if os.path.exists(args.log_dir):
        shutil.rmtree(args.log_dir)
    summary_writer = tf.summary.FileWriter(
        args.log_dir, graph=tf.get_default_graph())

    # Set up tf session and initialize variables.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = False
    sess = tf.Session(config=config)

    sess.run(tf.global_variables_initializer())
    # Local variables hold the streaming-metric accumulators.
    sess.run(tf.local_variables_initializer())

    # Saver for storing checkpoints of the model.
    saver = tf.train.Saver(var_list=trainable, max_to_keep=40)

    # Restore before snapshot_dir is wiped below.
    if args.restore_from is not None:
        load(saver, sess, args.restore_from)

    if os.path.exists(args.save_dir):
        shutil.rmtree(args.save_dir)
    os.makedirs(args.save_dir)

    if os.path.exists(args.snapshot_dir):
        shutil.rmtree(args.snapshot_dir)

    # Start queue threads.
    threads = tf.train.start_queue_runners(coord=coord, sess=sess)

    try:
        # Iterate over training steps.
        for step in range(args.num_steps):
            start_time = time.time()

            if step % args.save_pred_every == 0:
                loss_value, images, labels, preds, _ = sess.run(
                    [loss, image_batch, label_batch, pred, optim],
                    feed_dict={is_training: True})
                fig, axes = plt.subplots(
                    args.save_num_images, 3, figsize=(16, 12))
                for i in range(args.save_num_images):
                    # Add the mean back and reverse the channel axis
                    # (presumably BGR -> RGB; confirm against ImageReader).
                    axes.flat[i * 3].set_title('data')
                    axes.flat[i * 3].imshow(
                        (images[i] + IMG_MEAN)[:, :, ::-1].astype(np.uint8))

                    axes.flat[i * 3 + 1].set_title('mask')
                    axes.flat[i * 3 + 1].imshow(
                        decode_labels(labels[i, :, :, 0]))

                    axes.flat[i * 3 + 2].set_title('pred')
                    axes.flat[i * 3 + 2].imshow(
                        decode_labels(preds[i, :, :, 0]))
                fig.savefig(
                    os.path.join(args.save_dir, str(start_time) + ".png"))
                plt.close(fig)
                save(saver, sess, args.snapshot_dir, step)

                # Validation.
                for val_idx in range(VAL_LOOP):
                    start_time_vali = time.time()
                    (loss_vali, val_images, val_labels, val_preds,
                     recall_vali, precision_vali, mIoU_vali_value,
                     _) = sess.run(
                         [loss, image_batch, label_batch, pred, recall,
                          precision, mIoU_vali, update_op_vali],
                         feed_dict={is_training: False})
                    duration_vali = time.time() - start_time_vali

                    print('validation step {:<5d}\tbatch:{:<3d}, recall: {:.3f}, precision: {:.3f}, mIoU: {:.3f}'.format(
                        step, val_idx, recall_vali, precision_vali,
                        mIoU_vali_value))
                    print('validation step {:<5d}\tbatch:{:<3d}, loss = {:.5f}, ({:.5f} sec/batch)'.format(
                        step, val_idx, loss_vali, duration_vali))

                    # Save and close one figure per validation image.  The
                    # number of images comes from the fetched batch itself,
                    # which was dequeued with args.batch_size.
                    for j in range(len(val_images)):
                        fig, axes = plt.subplots(1, 3, figsize=(16, 12))
                        axes.flat[0].set_title('data')
                        axes.flat[0].imshow(
                            (val_images[j] + IMG_MEAN)[:, :, ::-1].astype(
                                np.uint8))

                        axes.flat[1].set_title('mask')
                        axes.flat[1].imshow(
                            decode_labels(val_labels[j, :, :, 0]))

                        axes.flat[2].set_title('pred')
                        axes.flat[2].imshow(
                            decode_labels(val_preds[j, :, :, 0]))

                        fig.savefig(os.path.join(
                            args.save_dir,
                            str(start_time) + '_' + str(val_idx) + '_' +
                            str(j) + "test.png"))
                        plt.close(fig)
            else:
                (loss_value, _, summary, recall_value, precision_value,
                 mIoU_value, _) = sess.run(
                     [loss, optim, merged, recall, precision, mIoU,
                      update_op],
                     feed_dict={is_training: True})
                print('step {:<5d}\trecall: {:.3f}, precision: {:.3f}, mIoU: {:.3f}'.format(
                    step, recall_value, precision_value, mIoU_value))

                summary_writer.add_summary(summary, step)

            duration = time.time() - start_time
            print('step {:<5d}\tloss = {:.5f}, ({:.5f} sec/step)'.format(
                step, loss_value, duration))
    finally:
        coord.request_stop()
        summary_writer.close()
        coord.join(threads)
示例#5
0
def main():
    """Create the model and start the training.

    Builds the input queue and ``DeepLabLFOVModel``, whose ``loss`` returns
    three groups of tensors (main loss, attention loss, upscaled prediction,
    output attention map, predicted attention map, 3-D prediction).  The
    summed attention loss is minimised with a momentum optimiser over all
    trainable variables whose name does not contain ``conv1`` .. ``conv4``.

    Weights are restored from ``args.restore_from`` for every trainable
    variable except those named ``filter_of_attention_map`` /
    ``aggregated_feat``; those, plus variables whose name contains
    ``Variable`` or ``Momentum``, are initialised explicitly.

    The learning rate is halved every 2000 steps.  Summaries are written to
    ``args.summay_dir`` every ``args.summary_freq`` steps; every
    ``args.save_pred_every`` steps a prediction figure is saved into
    ``args.save_dir`` and a full checkpoint into ``args.snapshot_dir``.

    Queue threads are stopped and the summary writer is closed even when a
    step raises.
    """
    args = get_arguments()

    # ``args.input_size`` is a "height,width" string.
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    # Set up tf session.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)

    # Create queue coordinator.
    coord = tf.train.Coordinator()

    # Load reader.
    with tf.name_scope("create_inputs"):
        reader = ImageReader(args.data_dir, args.data_list, input_size,
                             RANDOM_SCALE, coord)
        image_batch, label_batch = reader.dequeue(args.batch_size)

    # Create network.
    net = DeepLabLFOVModel()

    # Define the loss and optimisation parameters.
    (main_loss_1, attention_loss_1, pre_upscaled_1, output_attention_map_1,
     attention_map_1_predicted, predict_3d_1,
     main_loss_2, attention_loss_2, pre_upscaled_2, output_attention_map_2,
     attention_map_2_predicted, predict_3d_2,
     main_loss_3, attention_loss_3, pre_upscaled_3, output_attention_map_3,
     attention_map_3_predicted, predict_3d_3) = net.loss(image_batch,
                                                         label_batch)

    main_loss = main_loss_1 + main_loss_2 + main_loss_3
    attention_loss = attention_loss_1 + attention_loss_2 + attention_loss_3

    learning_rate = tf.placeholder(tf.float32, shape=[])
    att_learning_rate = tf.placeholder(tf.float32, shape=[])
    trainable = tf.trainable_variables()

    # Variables whose name contains one of these substrings are excluded
    # from the attention optimiser.
    frozen_trainalbe = [u'conv1', u'conv2', u'conv3', u'conv4']
    final_trainable = [
        x for x in trainable
        if not any(f in x.name for f in frozen_trainalbe)
    ]
    print("====final_trainable shape check====")
    for v in final_trainable:
        print("{}:  {}".format(v.name, v.get_shape()))

    optim = tf.train.MomentumOptimizer(learning_rate=learning_rate,
                                       momentum=0.9)
    att_optim = tf.train.MomentumOptimizer(learning_rate=att_learning_rate,
                                           momentum=0.9).minimize(
                                               attention_loss,
                                               var_list=final_trainable)

    # NOTE(review): ``train_tf_style`` is built but never fetched in the
    # training loop below, so only the attention loss is optimised.  The op
    # is kept because building it creates the optimiser's slot variables,
    # which are part of the graph - confirm whether the main loss is meant
    # to be trained here.
    grads_and_vars_tf_style = optim.compute_gradients(main_loss,
                                                      tf.trainable_variables())
    train_tf_style = optim.apply_gradients(grads_and_vars_tf_style)

    pred_result = net.preds(image_batch)

    # Image tensors for TensorBoard, produced by Python-side helpers.
    images_summary = tf.py_func(inv_preprocess, [image_batch, SAVE_NUM_IMAGES],
                                tf.uint8)
    labels_summary = tf.py_func(decode_labels_by_batch,
                                [label_batch, SAVE_NUM_IMAGES], tf.uint8)
    preds_1_summary = tf.py_func(decode_labels_by_batch,
                                 [pre_upscaled_1, SAVE_NUM_IMAGES], tf.uint8)
    preds_2_summary = tf.py_func(decode_labels_by_batch,
                                 [pre_upscaled_2, SAVE_NUM_IMAGES], tf.uint8)
    preds_3_summary = tf.py_func(decode_labels_by_batch,
                                 [pre_upscaled_3, SAVE_NUM_IMAGES], tf.uint8)

    att_1_summary = tf.py_func(single_channel_process,
                               [output_attention_map_1, SAVE_NUM_IMAGES],
                               tf.uint8)
    att_2_summary = tf.py_func(single_channel_process,
                               [output_attention_map_2, SAVE_NUM_IMAGES],
                               tf.uint8)
    att_3_summary = tf.py_func(single_channel_process,
                               [output_attention_map_3, SAVE_NUM_IMAGES],
                               tf.uint8)

    predict_3d_1_summary = tf.py_func(single_channel_process,
                                      [predict_3d_1, SAVE_NUM_IMAGES],
                                      tf.uint8)
    predict_3d_2_summary = tf.py_func(single_channel_process,
                                      [predict_3d_2, SAVE_NUM_IMAGES],
                                      tf.uint8)
    predict_3d_3_summary = tf.py_func(single_channel_process,
                                      [predict_3d_3, SAVE_NUM_IMAGES],
                                      tf.uint8)

    attention_map_1_predicted_summary = tf.py_func(
        single_channel_process, [attention_map_1_predicted, SAVE_NUM_IMAGES],
        tf.uint8)
    attention_map_2_predicted_summary = tf.py_func(
        single_channel_process, [attention_map_2_predicted, SAVE_NUM_IMAGES],
        tf.uint8)
    attention_map_3_predicted_summary = tf.py_func(
        single_channel_process, [attention_map_3_predicted, SAVE_NUM_IMAGES],
        tf.uint8)

    # Define summaries.  Histograms are registered exactly once, so each
    # variable gets a single ``<name>/values`` histogram.
    summary_list = []
    for var in tf.trainable_variables():
        summary_list.append(tf.summary.histogram(var.op.name + "/values", var))

    with tf.name_scope("loss_summary"):
        summary_list.append(tf.summary.scalar("main_loss", main_loss))
        summary_list.append(tf.summary.scalar("attention_loss",
                                              attention_loss))
        summary_list.append(tf.summary.scalar("loss_1", main_loss_1))
        summary_list.append(tf.summary.scalar("loss_2", main_loss_2))
        summary_list.append(tf.summary.scalar("loss_3", main_loss_3))

    with tf.name_scope("image_summary"):
        # Panels are concatenated along axis 2.
        summary_list.append(
            tf.summary.image(
                'total_image',
                tf.concat([
                    images_summary, labels_summary, preds_1_summary,
                    att_1_summary, preds_2_summary, att_2_summary,
                    preds_3_summary, att_3_summary
                ], 2),
                max_outputs=SAVE_NUM_IMAGES))

        summary_list.append(
            tf.summary.image(
                'attention_image',
                tf.concat([
                    images_summary, labels_summary, att_1_summary,
                    attention_map_1_predicted_summary, att_2_summary,
                    attention_map_2_predicted_summary, att_3_summary,
                    attention_map_3_predicted_summary
                ], 2),
                max_outputs=SAVE_NUM_IMAGES))

        summary_list.append(
            tf.summary.image(
                'confidence_map',
                tf.concat([
                    images_summary, labels_summary, predict_3d_1_summary,
                    predict_3d_2_summary, predict_3d_3_summary
                ], 2),
                max_outputs=SAVE_NUM_IMAGES))

    merged_summary_op = tf.summary.merge(summary_list)

    summary_writer = tf.summary.FileWriter(args.summay_dir, sess.graph)

    # Check the shapes.
    print("====global_variables shape check====")
    for v in tf.global_variables():
        print("{}:  {}".format(v.name, v.get_shape()))
    print("====trainable_variables shape check====")
    for v in tf.trainable_variables():
        print("{}:  {}".format(v.name, v.get_shape()))

    # Variables named "filter_of_attention_map" / "aggregated_feat" are not
    # restored from the checkpoint; they are initialised below instead.
    var_to_be_restored = [
        x for x in trainable
        if u'filter_of_attention_map' not in x.name
        and u'aggregated_feat' not in x.name
    ]

    # Variables to initialise explicitly, collected per name tag in the
    # same order as before (a variable matching two tags appears twice).
    uninitialized_vars = []
    for tag in (u'filter_of_attention_map', u'aggregated_feat', u'Variable',
                u'Momentum'):
        uninitialized_vars.extend(
            [x for x in tf.global_variables() if tag in x.name])

    print("====var_to_be_restored shape check====")
    for tmp in var_to_be_restored:
        print("variable name: {},shape: {}, type: {}".format(
            tmp.name, tmp.get_shape(), type(tmp.name)))

    print("====uninitialized_vars shape check====")
    for tmp in uninitialized_vars:
        print("variable name: {},shape: {}, type: {}".format(
            tmp.name, tmp.get_shape(), type(tmp.name)))

    # readSaver restores the pretrained subset; writeSaver checkpoints every
    # variable.
    readSaver = tf.train.Saver(var_list=var_to_be_restored)
    writeSaver = tf.train.Saver(max_to_keep=40)
    if args.restore_from is not None:
        load(readSaver, sess, args.restore_from)

    # Start queue threads.
    threads = tf.train.start_queue_runners(coord=coord, sess=sess)

    try:
        if not os.path.exists(args.save_dir):
            os.makedirs(args.save_dir)

        init_new_vars_op = tf.variables_initializer(uninitialized_vars)
        sess.run(init_new_vars_op)

        print("====varaible initialize status check====")
        init_flag = sess.run(
            tf.stack([
                tf.is_variable_initialized(v) for v in tf.global_variables()
            ]))
        for v, flag in zip(tf.global_variables(), init_flag):
            if not flag:
                print("====warning====")
            print("name: {},  shape: {}, is_variable_initialized:{}".format(
                v.name, v.get_shape(), flag))

        with tf.name_scope("parameter_count"):
            parameter_count = tf.reduce_sum([
                tf.reduce_prod(tf.shape(v)) for v in tf.trainable_variables()
            ])

        print("parameter_count =", sess.run(parameter_count))

        # Initial summary, written without a step value.
        summary_str = sess.run(merged_summary_op)
        summary_writer.add_summary(summary_str)

        # Iterate over training steps (starts at 1, so it performs
        # args.num_steps - 1 iterations).
        for step in range(1, args.num_steps):
            start_time = time.time()

            # Halve the learning rate every 2000 steps.
            cur_lr = args.learning_rate / math.pow(2, step // 2000)
            print("current learning rate: {}".format(cur_lr))

            # Only the scalar losses are fetched; the prediction and
            # attention maps that used to be fetched here were never used.
            (_, _attention_loss, _attention_loss_1, _attention_loss_2,
             _attention_loss_3, _main_loss, _main_loss_1, _main_loss_2,
             _main_loss_3) = sess.run(
                 [att_optim, attention_loss, attention_loss_1,
                  attention_loss_2, attention_loss_3, main_loss, main_loss_1,
                  main_loss_2, main_loss_3],
                 feed_dict={learning_rate: cur_lr,
                            att_learning_rate: cur_lr})

            print(
                'step {:d}, main_loss: {:.5f}, loss 1: {:.5f}, loss 2: {:.5f}, loss 3: {:.5f}'
                .format(step, _main_loss, _main_loss_1, _main_loss_2,
                        _main_loss_3))
            print(
                'step {:d}, attention_loss: {:.5f}, att_loss 1: {:.5f}, att_loss 2: {:.5f}, att_loss 3: {:.5f}'
                .format(step, _attention_loss, _attention_loss_1,
                        _attention_loss_2, _attention_loss_3))

            if step % args.summary_freq == 0:
                print("write summay...")
                # Generate summary for TensorBoard.
                summary_str = sess.run(merged_summary_op)
                summary_writer.add_summary(summary_str, step)

            if step % args.save_pred_every == 0:
                print("save a predict as picture...")
                # This is a separate run from the training step above.
                preds_result_value, images, labels = sess.run(
                    [pred_result, image_batch, label_batch])

                fig, axes = plt.subplots(args.save_num_images, 3,
                                         figsize=(16, 12))
                print("images type: {}".format(type(images)))
                print("labels type: {}".format(type(labels)))

                print("preds_result shape: {}".format(
                    preds_result_value.shape))
                for i in range(args.save_num_images):
                    # Add the mean back and reverse the channel axis
                    # (presumably BGR -> RGB; confirm against ImageReader).
                    axes.flat[i * 3].set_title('data')
                    axes.flat[i * 3].imshow(
                        (images[i] + IMG_MEAN)[:, :, ::-1].astype(np.uint8))

                    axes.flat[i * 3 + 1].set_title('mask')
                    axes.flat[i * 3 + 1].imshow(
                        decode_labels(labels[i, :, :, 0]))

                    axes.flat[i * 3 + 2].set_title('pred')
                    axes.flat[i * 3 + 2].imshow(
                        decode_labels(preds_result_value[i, :, :, 0]))
                # Join instead of concatenating so a save_dir without a
                # trailing separator still receives the file.
                fig.savefig(
                    os.path.join(args.save_dir, str(start_time) + ".png"))
                plt.close(fig)
                save(writeSaver, sess, args.snapshot_dir, step)

            duration = time.time() - start_time
            print('step {:d} \t  ({:.3f} sec/step)'.format(step, duration))
    finally:
        coord.request_stop()
        summary_writer.close()
        coord.join(threads)