def training(FLAGS, is_finetune=False):
    """Train or fine-tune the segmentation network on CamVid.

    Builds a fresh graph (input placeholders, the CamVid input pipelines, the
    ``inference`` loss/prediction ops and the ``train`` op) and runs
    ``FLAGS.max_steps`` training steps.  Along the way it:

    * every 10 steps prints loss/throughput and the per-class accuracy of the
      current training batch;
    * every 100 steps runs ``TEST_ITER`` validation batches and writes the
      mean loss, accuracy and mean IU plus the merged graph summaries to
      ``FLAGS.log_dir``;
    * every 1000 steps, and on the final step, saves a checkpoint to
      ``FLAGS.log_dir``.

    Args:
        FLAGS: Configuration object.  The attributes read here are
            ``max_steps``, ``batch_size``, ``log_dir``, ``image_dir``,
            ``val_dir``, ``finetune``, ``image_w``, ``image_h`` and
            ``image_c``.
        is_finetune: If truthy, variables are restored from ``FLAGS.finetune``
            instead of being initialised, and step numbering resumes from the
            integer after the last ``-`` in that path.

    Raises:
        AssertionError: If the training loss becomes NaN.
    """
    max_steps = FLAGS.max_steps
    batch_size = FLAGS.batch_size
    train_dir = FLAGS.log_dir  # e.g. /tmp3/first350/TensorFlow/Logs
    image_dir = FLAGS.image_dir  # e.g. /tmp3/first350/SegNet-Tutorial/CamVid/train.txt
    val_dir = FLAGS.val_dir  # e.g. /tmp3/first350/SegNet-Tutorial/CamVid/val.txt
    finetune_ckpt = FLAGS.finetune
    image_w = FLAGS.image_w
    image_h = FLAGS.image_h
    image_c = FLAGS.image_c
    # The start step is parsed from the checkpoint name ("model.ckpt-9000" ->
    # 9000); change this if your model is stored by a different convention.
    startstep = int(finetune_ckpt.split('-')[-1]) if is_finetune else 0
    end_step = startstep + max_steps

    image_filenames, label_filenames = get_filename_list(image_dir)
    val_image_filenames, val_label_filenames = get_filename_list(val_dir)

    with tf.Graph().as_default():

        train_data_node = tf.placeholder(
            tf.float32, shape=[batch_size, image_h, image_w, image_c])

        train_labels_node = tf.placeholder(
            tf.int64, shape=[batch_size, image_h, image_w, 1])

        phase_train = tf.placeholder(tf.bool, name='phase_train')

        global_step = tf.Variable(0, trainable=False)

        # For CamVid
        images, labels = CamVidInputs(image_filenames, label_filenames,
                                      batch_size)

        val_images, val_labels = CamVidInputs(val_image_filenames,
                                              val_label_filenames, batch_size)

        # Build a Graph that computes the logits predictions from the
        # inference model.
        loss, eval_prediction = inference(train_data_node, train_labels_node,
                                          batch_size, phase_train)

        # Build a Graph that trains the model with one batch of examples and
        # updates the model parameters.
        train_op = train(loss, global_step)

        saver = tf.train.Saver(tf.global_variables())

        # Merged before the validation scalars below are created, so running
        # summary_op never requires feeding their placeholders.
        summary_op = tf.summary.merge_all()

        with tf.Session() as sess:
            # Same truthiness test as used for startstep above, so weights
            # and step numbering can never disagree.
            if is_finetune:
                saver.restore(sess, finetune_ckpt)
            else:
                sess.run(tf.global_variables_initializer())

            # Start the queue runners.
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            summary_writer = tf.summary.FileWriter(train_dir, sess.graph)

            try:
                # Summary placeholders for the validation metrics.
                average_pl = tf.placeholder(tf.float32)
                acc_pl = tf.placeholder(tf.float32)
                iu_pl = tf.placeholder(tf.float32)
                average_summary = tf.summary.scalar("test_average_loss",
                                                    average_pl)
                acc_summary = tf.summary.scalar("test_accuracy", acc_pl)
                iu_summary = tf.summary.scalar("Mean_IU", iu_pl)

                for step in range(startstep, end_step):
                    image_batch, label_batch = sess.run([images, labels])
                    # since we still use mini-batches in validation, still set
                    # bn-layer phase_train = True
                    feed_dict = {
                        train_data_node: image_batch,
                        train_labels_node: label_batch,
                        phase_train: True
                    }
                    start_time = time.time()
                    _, loss_value = sess.run([train_op, loss],
                                             feed_dict=feed_dict)
                    duration = time.time() - start_time

                    assert not np.isnan(
                        loss_value), 'Model diverged with loss = NaN'

                    if step % 10 == 0:
                        examples_per_sec = batch_size / duration
                        sec_per_batch = float(duration)

                        format_str = (
                            '%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                            'sec/batch)')
                        print(format_str % (datetime.now(), step, loss_value,
                                            examples_per_sec, sec_per_batch))

                        # eval current training batch per-class accuracy
                        pred = sess.run(eval_prediction, feed_dict=feed_dict)
                        per_class_acc(pred, label_batch)

                    if step % 100 == 0:
                        print("start validating.....")
                        total_val_loss = 0.0
                        hist = np.zeros((NUM_CLASSES, NUM_CLASSES))
                        for _ in range(int(TEST_ITER)):
                            val_images_batch, val_labels_batch = sess.run(
                                [val_images, val_labels])

                            _val_loss, _val_pred = sess.run(
                                [loss, eval_prediction],
                                feed_dict={
                                    train_data_node: val_images_batch,
                                    train_labels_node: val_labels_batch,
                                    phase_train: True
                                })
                            total_val_loss += _val_loss
                            hist += get_hist(_val_pred, val_labels_batch)
                        print("val loss: ", total_val_loss / TEST_ITER)
                        # hist is accumulated per (class, class) pair; its
                        # diagonal is treated as the correctly labelled count.
                        acc_total = np.diag(hist).sum() / hist.sum()
                        iu = np.diag(hist) / (hist.sum(1) + hist.sum(0) -
                                              np.diag(hist))
                        test_summary_str = sess.run(
                            average_summary,
                            feed_dict={average_pl: total_val_loss / TEST_ITER})
                        acc_summary_str = sess.run(
                            acc_summary, feed_dict={acc_pl: acc_total})
                        # nanmean: classes absent from the histogram yield
                        # 0/0 = NaN and are left out of the mean.
                        iu_summary_str = sess.run(
                            iu_summary, feed_dict={iu_pl: np.nanmean(iu)})
                        print_hist_summery(hist)
                        print(" end validating.... ")

                        summary_str = sess.run(summary_op, feed_dict=feed_dict)
                        summary_writer.add_summary(summary_str, step)
                        summary_writer.add_summary(test_summary_str, step)
                        summary_writer.add_summary(acc_summary_str, step)
                        summary_writer.add_summary(iu_summary_str, step)

                    # Save the model checkpoint periodically and on the last
                    # step.  The last step is end_step - 1, not max_steps - 1,
                    # because a fine-tuning run starts counting at startstep.
                    if step % 1000 == 0 or (step + 1) == end_step:
                        checkpoint_path = os.path.join(train_dir, 'model.ckpt')
                        saver.save(sess, checkpoint_path, global_step=step)
            finally:
                # Runs on normal completion and on errors (e.g. the NaN
                # assert): flush pending events and stop the input threads.
                summary_writer.close()
                coord.request_stop()
                coord.join(threads)
Beispiel #2
0
def model(one):
    """Resume CamVid training from a fixed checkpoint.

    Configuration is read from the module-level globals ``max_steps``,
    ``batch_size``, ``image_w``, ``image_h``, ``image_c``, ``image_dir`` and
    ``log_dir``.  The function restores ``/path/to/files/model.ckpt-9000`` and
    trains from step 9000 up to (but not including) step 30000.  Every 10
    steps it prints loss/throughput and the per-class accuracy of the current
    batch; checkpoints are written to ``log_dir`` every 1000 steps, at step
    ``max_steps - 1`` and on the final step.

    Args:
        one: Unused by this function.

    Raises:
        AssertionError: If the training loss becomes NaN.
    """
    global max_steps, batch_size, image_w, image_h, image_c, image_dir, log_dir
    startstep = 9000
    end_step = 30000
    finetune_ckpt = '/path/to/files/model.ckpt-9000'
    image_filenames, label_filenames = get_filename_list(image_dir)
    with tf.Graph().as_default():
        train_data_node = tf.placeholder(
            tf.float32, shape=[batch_size, image_h, image_w, image_c])
        train_labels_node = tf.placeholder(
            tf.int64, shape=[batch_size, image_h, image_w, 1])
        phase_train = tf.placeholder(tf.bool, name='phase_train')
        global_step = tf.Variable(0, trainable=False)

        images, labels = CamVidInputs(image_filenames, label_filenames,
                                      batch_size)
        # Call form of print: same output under Python 2, valid in Python 3,
        # and consistent with the other print calls in this function.
        print(images.shape)
        print(labels.shape)
        loss, eval_prediction = inference(train_data_node, train_labels_node,
                                          phase_train)
        train_op = train(loss, global_step)

        saver = tf.train.Saver(tf.global_variables())

        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True

        with tf.Session(config=config) as sess:
            # All variables come from the checkpoint, so no initializer runs.
            saver.restore(sess, finetune_ckpt)
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            # Only the graph is written to the event file; this function
            # emits no scalar summaries.
            summary_writer = tf.summary.FileWriter(log_dir, sess.graph)

            try:
                for step in range(startstep, end_step):
                    image_batch, label_batch = sess.run([images, labels])
                    feed_dict = {
                        train_data_node: image_batch,
                        train_labels_node: label_batch,
                        phase_train: True
                    }

                    start_time = time.time()
                    _, loss_value = sess.run([train_op, loss],
                                             feed_dict=feed_dict)
                    duration = time.time() - start_time
                    assert not np.isnan(
                        loss_value), 'Model diverged with loss = NaN'

                    if step % 10 == 0:
                        examples_per_sec = batch_size / duration
                        sec_per_batch = float(duration)

                        format_str = (
                            '%s: step %d, loss = %.2f(%.1f examples/sec; %.3f '
                            'sec/batch)')
                        print(format_str % (datetime.now(), step, loss_value,
                                            examples_per_sec, sec_per_batch))

                        # eval current training batch per-class accuracy
                        pred = sess.run(eval_prediction, feed_dict=feed_dict)
                        per_class_acc(pred, label_batch)

                    # Checked on every step rather than only on multiples of
                    # ten, where "step + 1 == limit" could almost never hold.
                    # end_step is included because the loop runs to 30000
                    # regardless of max_steps, so the final weights are saved.
                    if step % 1000 == 0 or (step + 1) in (max_steps, end_step):
                        checkpoint_path = os.path.join(log_dir, 'model.ckpt')
                        saver.save(sess, checkpoint_path, global_step=step)
            finally:
                # Runs on normal completion and on errors (e.g. the NaN
                # assert): flush the event file and stop the input threads.
                summary_writer.close()
                coord.request_stop()
                coord.join(threads)
Beispiel #3
0
def training(FLAGS, is_finetune=False):
  """Train or fine-tune the segmentation network on CamVid.

  Builds a fresh graph (input placeholders, the CamVid input pipelines, the
  ``inference`` loss/prediction ops and the ``train`` op) and runs
  ``FLAGS.max_steps`` training steps.  Along the way it:

  * every 10 steps prints loss/throughput and the per-class accuracy of the
    current training batch;
  * every 100 steps runs ``TEST_ITER`` validation batches and writes the mean
    loss, accuracy and mean IU plus the merged graph summaries to
    ``FLAGS.log_dir``;
  * every 1000 steps, and on the final step, saves a checkpoint to
    ``FLAGS.log_dir``.

  Args:
    FLAGS: Configuration object.  The attributes read here are ``max_steps``,
      ``batch_size``, ``log_dir``, ``image_dir``, ``val_dir``, ``finetune``,
      ``image_w``, ``image_h`` and ``image_c``.
    is_finetune: If truthy, variables are restored from ``FLAGS.finetune``
      instead of being initialised, and step numbering resumes from the
      integer after the last ``-`` in that path.

  Raises:
    AssertionError: If the training loss becomes NaN.
  """
  max_steps = FLAGS.max_steps
  batch_size = FLAGS.batch_size
  train_dir = FLAGS.log_dir  # e.g. /tmp3/first350/TensorFlow/Logs
  image_dir = FLAGS.image_dir  # e.g. /tmp3/first350/SegNet-Tutorial/CamVid/train.txt
  val_dir = FLAGS.val_dir  # e.g. /tmp3/first350/SegNet-Tutorial/CamVid/val.txt
  finetune_ckpt = FLAGS.finetune
  image_w = FLAGS.image_w
  image_h = FLAGS.image_h
  image_c = FLAGS.image_c
  # The start step is parsed from the checkpoint name ("model.ckpt-9000" ->
  # 9000); change this if your model is stored by a different convention.
  startstep = int(finetune_ckpt.split('-')[-1]) if is_finetune else 0
  end_step = startstep + max_steps

  image_filenames, label_filenames = get_filename_list(image_dir)
  val_image_filenames, val_label_filenames = get_filename_list(val_dir)

  with tf.Graph().as_default():

    train_data_node = tf.placeholder(
        tf.float32, shape=[batch_size, image_h, image_w, image_c])

    train_labels_node = tf.placeholder(
        tf.int64, shape=[batch_size, image_h, image_w, 1])

    phase_train = tf.placeholder(tf.bool, name='phase_train')

    global_step = tf.Variable(0, trainable=False)

    # For CamVid
    images, labels = CamVidInputs(image_filenames, label_filenames, batch_size)

    val_images, val_labels = CamVidInputs(
        val_image_filenames, val_label_filenames, batch_size)

    # Build a Graph that computes the logits predictions from the inference
    # model.
    loss, eval_prediction = inference(
        train_data_node, train_labels_node, batch_size, phase_train)

    # Build a Graph that trains the model with one batch of examples and
    # updates the model parameters.
    train_op = train(loss, global_step)

    saver = tf.train.Saver(tf.global_variables())

    # Merged before the validation scalars below are created, so running
    # summary_op never requires feeding their placeholders.
    summary_op = tf.summary.merge_all()

    with tf.Session() as sess:
      # Same truthiness test as used for startstep above, so weights and step
      # numbering can never disagree.
      if is_finetune:
        saver.restore(sess, finetune_ckpt)
      else:
        sess.run(tf.global_variables_initializer())

      # Start the queue runners.
      coord = tf.train.Coordinator()
      threads = tf.train.start_queue_runners(sess=sess, coord=coord)
      summary_writer = tf.summary.FileWriter(train_dir, sess.graph)

      try:
        # Summary placeholders for the validation metrics.
        average_pl = tf.placeholder(tf.float32)
        acc_pl = tf.placeholder(tf.float32)
        iu_pl = tf.placeholder(tf.float32)
        average_summary = tf.summary.scalar("test_average_loss", average_pl)
        acc_summary = tf.summary.scalar("test_accuracy", acc_pl)
        iu_summary = tf.summary.scalar("Mean_IU", iu_pl)

        for step in range(startstep, end_step):
          image_batch, label_batch = sess.run([images, labels])
          # since we still use mini-batches in validation, still set bn-layer
          # phase_train = True
          feed_dict = {
            train_data_node: image_batch,
            train_labels_node: label_batch,
            phase_train: True
          }
          start_time = time.time()
          _, loss_value = sess.run([train_op, loss], feed_dict=feed_dict)
          duration = time.time() - start_time

          assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

          if step % 10 == 0:
            examples_per_sec = batch_size / duration
            sec_per_batch = float(duration)

            format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                          'sec/batch)')
            print(format_str % (datetime.now(), step, loss_value,
                                examples_per_sec, sec_per_batch))

            # eval current training batch per-class accuracy
            pred = sess.run(eval_prediction, feed_dict=feed_dict)
            per_class_acc(pred, label_batch)

          if step % 100 == 0:
            print("start validating.....")
            total_val_loss = 0.0
            hist = np.zeros((NUM_CLASSES, NUM_CLASSES))
            for _ in range(int(TEST_ITER)):
              val_images_batch, val_labels_batch = sess.run(
                  [val_images, val_labels])

              _val_loss, _val_pred = sess.run([loss, eval_prediction], feed_dict={
                train_data_node: val_images_batch,
                train_labels_node: val_labels_batch,
                phase_train: True
              })
              total_val_loss += _val_loss
              hist += get_hist(_val_pred, val_labels_batch)
            print("val loss: ", total_val_loss / TEST_ITER)
            # hist is accumulated per (class, class) pair; its diagonal is
            # treated as the correctly labelled count.
            acc_total = np.diag(hist).sum() / hist.sum()
            iu = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))
            test_summary_str = sess.run(
                average_summary,
                feed_dict={average_pl: total_val_loss / TEST_ITER})
            acc_summary_str = sess.run(
                acc_summary, feed_dict={acc_pl: acc_total})
            # nanmean: classes absent from the histogram yield 0/0 = NaN and
            # are left out of the mean.
            iu_summary_str = sess.run(
                iu_summary, feed_dict={iu_pl: np.nanmean(iu)})
            print_hist_summery(hist)
            print(" end validating.... ")

            summary_str = sess.run(summary_op, feed_dict=feed_dict)
            summary_writer.add_summary(summary_str, step)
            summary_writer.add_summary(test_summary_str, step)
            summary_writer.add_summary(acc_summary_str, step)
            summary_writer.add_summary(iu_summary_str, step)

          # Save the model checkpoint periodically and on the last step.  The
          # last step is end_step - 1, not max_steps - 1, because a
          # fine-tuning run starts counting at startstep.
          if step % 1000 == 0 or (step + 1) == end_step:
            checkpoint_path = os.path.join(train_dir, 'model.ckpt')
            saver.save(sess, checkpoint_path, global_step=step)
      finally:
        # Runs on normal completion and on errors (e.g. the NaN assert):
        # flush pending events and stop the input threads.
        summary_writer.close()
        coord.request_stop()
        coord.join(threads)
Beispiel #4
0
def training(FLAGS, is_finetune=False):
    """Train or fine-tune the segmentation network with early stopping.

    Builds a fresh graph (input placeholders, ``OSMInputs`` pipelines, the
    ``inference`` loss/prediction ops and the ``train`` op) and runs up to
    ``FLAGS.max_steps`` training steps.  Along the way it:

    * every 100 steps prints loss/throughput and the remaining run time,
      evaluates the per-class accuracy of the current training batch and
      writes an image summary (input, prediction, ground truth);
    * once per validation interval (``epochs_until_val`` epochs, where an
      epoch is ``num_train / batch_size`` steps) runs the validation set,
      writes loss/accuracy/mean-IU summaries, stores the early-stop state in
      graph variables and saves a checkpoint to ``FLAGS.log_dir``.

    Training ends early when the validation loss was higher than at the
    previous validation ``FLAGS.patience`` times in a row (an empty
    ``./FINISHED_SEGNET`` file is created), or when fewer than 300 seconds of
    the ``FLAGS.max_runtime`` hours remain after a validation (the
    ``./RUNNING_SEGNET`` file is removed if present).

    Args:
        FLAGS: Configuration object.  The attributes read here are
            ``max_steps``, ``batch_size``, ``log_dir``, ``train_dir``,
            ``val_dir``, ``finetune``, ``not_restore_last``, ``image_w``,
            ``image_h``, ``image_c``, ``datadir``, ``dataset``,
            ``learning_rate``, ``max_runtime``, ``use_weights``, ``patience``
            and ``num_class``.
        is_finetune: If truthy, variables found in the latest checkpoint
            under ``FLAGS.finetune`` are restored after initialisation;
            variables whose name contains ``conv_classifier`` are skipped
            when ``FLAGS.not_restore_last`` is set.

    Raises:
        ValueError: If ``is_finetune`` is set but no checkpoint is found
            under ``FLAGS.finetune``.
    """
    max_steps = FLAGS.max_steps
    batch_size = FLAGS.batch_size
    train_dir = FLAGS.log_dir  # e.g. /tmp3/first350/TensorFlow/Logs
    train_list = FLAGS.train_dir  # e.g. /tmp3/first350/SegNet-Tutorial/CamVid/train.txt
    val_dir = FLAGS.val_dir  # e.g. /tmp3/first350/SegNet-Tutorial/CamVid/val.txt
    finetune_ckpt = FLAGS.finetune
    not_restore_last = FLAGS.not_restore_last
    image_w = FLAGS.image_w
    image_h = FLAGS.image_h
    image_c = FLAGS.image_c
    datadir = FLAGS.datadir
    dataset = FLAGS.dataset
    lr = FLAGS.learning_rate
    max_runtime = FLAGS.max_runtime
    use_weights = FLAGS.use_weights
    print("DEBUG 558: " + str(use_weights))

    max_time_seconds = 3600 * max_runtime
    dataset_params = get_dataset_params(dataset)
    # Clamp to at least one iteration: a split smaller than one batch would
    # otherwise give 0 and cause a division/modulo by zero below.
    test_iter = max(1, int(dataset_params["num_val"] / batch_size))
    epoch_iter = max(1, int(dataset_params["num_train"] / batch_size))
    epochs_until_val = 1
    patience = FLAGS.patience
    print(test_iter)
    print('batchsize training: ' + str(batch_size))
    print('max_steps training: ' + str(max_steps))
    print('lr training: ' + str(lr))
    print('Epochs until save ' + str(epochs_until_val))

    # Step numbering continues from the restored steps_total variable (see
    # below) rather than from the checkpoint file name, so this offset is 0
    # for both fresh and fine-tuning runs.
    startstep = 0

    image_filenames, label_filenames = get_filename_list(train_list)
    val_image_filenames, val_label_filenames = get_filename_list(val_dir)

    with tf.Graph().as_default():

        train_data_node = tf.placeholder(
            tf.float32, shape=[batch_size, image_h, image_w, image_c])

        train_labels_node = tf.placeholder(
            tf.int64, shape=[batch_size, image_h, image_w, 1])

        phase_train = tf.placeholder(tf.bool, name='phase_train')

        global_step = tf.Variable(0, trainable=False)

        # Per-class loss weights, one table entry per supported dataset.
        if not use_weights:
            print('dont use weights')
            use_weights = tf.constant([1.0] * FLAGS.num_class)
        else:
            print('use weights')
            class_weight_table = {
                "detop15": [
                    0.975644, 1.025603, 0.601745, 6.600600, 1.328684, 0.454776
                ],
                "eutop25": [
                    0.970664, 1.031165, 0.790741, 5.320133, 1.384649, 0.718765
                ],
                "worldtiny2k": [
                    0.879195, 1.439660, 0.683112, 4.628286, 1.159291, 0.322113
                ],
                "eutop25_nores":
                [0.400486, 1.000000, 0.766842, 5.159342, 1.342801],
                "detop15_nores":
                [0.303529, 1.000000, 0.604396, 5.941638, 1.305352],
                "worldtiny2k_nores":
                [0.203351, 1.241845, 0.589249, 3.992340, 1.000000],
                "vaihingen": [
                    0.808506, 0.855016, 1.086051, 0.926584, 18.435326,
                    26.644663
                ],
                "kaggle_dstl": [
                    0.014317, 0.227888, 2.175962, 1.000000, 0.300450, 0.081639,
                    0.046646, 1.740426, 8.405148, 749.202109, 73.475000
                ],
            }
            if dataset in class_weight_table:
                use_weights = np.array(class_weight_table[dataset])
            else:
                # NOTE(review): the flag value itself is then passed on to
                # inference() unchanged -- confirm that is intended.
                print('Error: No weights for dataset ' + dataset +
                      ' could be found.')

        # Early-stop state lives in graph variables so that it is saved in,
        # and restored from, checkpoints.
        last_val_loss_tf = tf.Variable(10000.0, name="last_loss")
        steps_total_tf = tf.Variable(0, name="steps_total")
        val_increased_t_tf = tf.Variable(0, name="loss_increased_t")

        # Assign ops are built once and fed through placeholders.  Creating
        # tf.assign inside the training loop would add new nodes to the graph
        # at every validation.
        steps_total_pl = tf.placeholder(
            steps_total_tf.dtype.base_dtype, shape=[])
        last_val_loss_pl = tf.placeholder(
            last_val_loss_tf.dtype.base_dtype, shape=[])
        val_increased_t_pl = tf.placeholder(
            val_increased_t_tf.dtype.base_dtype, shape=[])
        steps_assign = tf.assign(steps_total_tf, steps_total_pl)
        last_val_assign = tf.assign(last_val_loss_tf, last_val_loss_pl)
        increased_assign = tf.assign(val_increased_t_tf, val_increased_t_pl)

        # For Inputs
        images, labels = OSMInputs(image_filenames, label_filenames,
                                   batch_size, datadir, dataset)

        val_images, val_labels = OSMInputs(val_image_filenames,
                                           val_label_filenames, batch_size,
                                           datadir, dataset)

        # Build a Graph that computes the logits predictions from the
        # inference model.
        loss, eval_prediction = inference(FLAGS.num_class, train_data_node,
                                          train_labels_node, batch_size,
                                          phase_train, use_weights)
        # Build a Graph that trains the model with one batch of examples and
        # updates the model parameters.
        train_op = train(loss, global_step, lr)

        saver = tf.train.Saver(tf.global_variables(), max_to_keep=10)

        # Merged before the validation/image summaries below are created, so
        # running summary_op never requires feeding their placeholders.
        summary_op = tf.summary.merge_all()

        glob_start_time = time.time()
        with tf.Session() as sess:
            # Always initialise first: variables missing from a restored
            # checkpoint (e.g. the early-stop variables) still need a value.
            sess.run(tf.global_variables_initializer())

            if is_finetune:
                ckpt = tf.train.get_checkpoint_state(finetune_ckpt)
                if ckpt is None or not ckpt.model_checkpoint_path:
                    raise ValueError(
                        "No checkpoint found in {}".format(finetune_ckpt))
                ckpt_path = ckpt.model_checkpoint_path

                # Restore only variables that exist in the checkpoint;
                # v.name[:-2] turns 'conv1/weights:0' into 'conv1/weights'
                # (thanks to https://stackoverflow.com/a/50216949/8862202).
                saved_names = get_tensors_in_checkpoint_file(
                    file_name=ckpt_path)
                vars_to_restore = [
                    v for v in tf.global_variables()
                    if ('conv_classifier' not in v.name
                        or not not_restore_last) and v.name[:-2] in saved_names
                ]
                loader = tf.train.Saver(vars_to_restore, max_to_keep=10)
                loader.restore(sess, ckpt_path)
                print("Restored model parameters from {}".format(ckpt_path))

            # debug, check early stop vars
            print(sess.run(last_val_loss_tf))
            print(sess.run(steps_total_tf))
            print(sess.run(val_increased_t_tf))

            # Start the queue runners.
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            summary_writer = tf.summary.FileWriter(train_dir, sess.graph)

            try:
                # Summary placeholders for the validation metrics.
                average_pl = tf.placeholder(tf.float32)
                acc_pl = tf.placeholder(tf.float32)
                iu_pl = tf.placeholder(tf.float32)
                average_summary = tf.summary.scalar("test_average_loss",
                                                    average_pl)
                acc_summary = tf.summary.scalar("test_accuracy", acc_pl)
                iu_summary = tf.summary.scalar("Mean_IU", iu_pl)

                # Built lazily on first use so the placeholder takes the
                # dtype of the actual image array, then reused: one summary
                # op means one stable TensorBoard tag and no graph growth.
                image_pl = None
                image_summary_op = None

                last_val_loss = sess.run(last_val_loss_tf)
                val_not_imp_t = sess.run(val_increased_t_tf)
                total_steps = sess.run(steps_total_tf)
                first_step = startstep + total_steps
                # Defined up front so the early-stop message cannot hit an
                # unbound name when the loss never improves in this run.
                best_model_step = first_step

                for step in trange(first_step,
                                   first_step + max_steps,
                                   desc='training',
                                   leave=True):
                    image_batch, label_batch = sess.run([images, labels])
                    # since we still use mini-batches in validation, still set
                    # bn-layer phase_train = True
                    feed_dict = {
                        train_data_node: image_batch,
                        train_labels_node: label_batch,
                        phase_train: True
                    }
                    start_time = time.time()
                    _, loss_value = sess.run([train_op, loss],
                                             feed_dict=feed_dict)
                    duration = time.time() - start_time

                    if step % 100 == 0:
                        examples_per_sec = batch_size / duration
                        sec_per_batch = float(duration)

                        # time update
                        elapsed = time.time() - glob_start_time
                        remaining = max_time_seconds - elapsed

                        format_str = (
                            '%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                            'sec/batch), %.1f seconds until stop')
                        print(format_str %
                              (datetime.now(), step, loss_value,
                               examples_per_sec, sec_per_batch, remaining))

                        # eval current training batch per-class accuracy
                        pred = sess.run(eval_prediction, feed_dict=feed_dict)
                        per_class_acc(pred, label_batch, dataset=dataset)

                        # pred is already a NumPy array, so take the argmax
                        # in NumPy instead of adding a tf.argmax op (and a
                        # constant holding pred) to the graph each time.
                        argmax = np.argmax(pred, axis=3)
                        sat = image_batch[:1]
                        im = predToLabelledImg(argmax[0])
                        gt = predToLabelledImg(label_batch[0])
                        # concat images to a single 4D array to get all
                        # images in a single line in tensorboard
                        sat_pred_gt = np.concatenate((sat, im, gt), 0)
                        if image_summary_op is None:
                            image_pl = tf.placeholder(
                                tf.as_dtype(sat_pred_gt.dtype),
                                shape=[None, None, None, None])
                            image_summary_op = tf.summary.image(
                                'sat_pred_gt', image_pl)
                        sat_pred_gt_summary = sess.run(
                            image_summary_op,
                            feed_dict={image_pl: sat_pred_gt})
                        summary_writer.add_summary(sat_pred_gt_summary, step)

                    # Parenthesised: "%" and "*" share precedence, so the
                    # unparenthesised form computes
                    # (step % epoch_iter) * epochs_until_val.
                    if step % (epoch_iter * epochs_until_val) == 0:
                        print("start validating.....")
                        total_val_loss = 0.0
                        hist = np.zeros((FLAGS.num_class, FLAGS.num_class))
                        for _ in range(test_iter):
                            val_images_batch, val_labels_batch = sess.run(
                                [val_images, val_labels])

                            _val_loss, _val_pred = sess.run(
                                [loss, eval_prediction],
                                feed_dict={
                                    train_data_node: val_images_batch,
                                    train_labels_node: val_labels_batch,
                                    phase_train: True
                                })
                            total_val_loss += _val_loss
                            hist += get_hist(_val_pred, val_labels_batch)
                        total_val_loss = total_val_loss / test_iter
                        print(
                            "val loss: {:.3f} , last val loss: {:.3f}".format(
                                total_val_loss, last_val_loss))
                        # hist is accumulated per (class, class) pair; its
                        # diagonal is treated as the correctly labelled count.
                        acc_total = np.diag(hist).sum() / hist.sum()
                        iu = np.diag(hist) / (hist.sum(1) + hist.sum(0) -
                                              np.diag(hist))
                        test_summary_str = sess.run(
                            average_summary,
                            feed_dict={average_pl: total_val_loss})
                        acc_summary_str = sess.run(
                            acc_summary, feed_dict={acc_pl: acc_total})
                        # nanmean: classes absent from the histogram yield
                        # 0/0 = NaN and are left out of the mean.
                        iu_summary_str = sess.run(
                            iu_summary, feed_dict={iu_pl: np.nanmean(iu)})
                        print_hist_summery(hist, dataset=dataset)
                        print(" end validating.... ")

                        # Early stopping compares against the previous
                        # validation loss, not the best loss seen so far.
                        if total_val_loss > last_val_loss:
                            val_not_imp_t = val_not_imp_t + 1
                            if val_not_imp_t >= patience:
                                print(
                                    "Terminated Training, Best Model (at step %d) saved %d validations ago"
                                    % (best_model_step, patience))
                                # Empty marker file signalling completion.
                                with open("./FINISHED_SEGNET", "w+"):
                                    pass
                                break
                        else:
                            val_not_imp_t = 0
                            best_model_step = step
                        print("Loss not since improved %d times" %
                              val_not_imp_t)
                        last_val_loss = total_val_loss

                        # update early stop tensors
                        print(sess.run(steps_assign,
                                       feed_dict={steps_total_pl: step}))
                        print(sess.run(
                            last_val_assign,
                            feed_dict={last_val_loss_pl: last_val_loss}))
                        print(sess.run(
                            increased_assign,
                            feed_dict={val_increased_t_pl: val_not_imp_t}))

                        summary_str = sess.run(summary_op, feed_dict=feed_dict)
                        summary_writer.add_summary(summary_str, step)
                        summary_writer.add_summary(test_summary_str, step)
                        summary_writer.add_summary(acc_summary_str, step)
                        summary_writer.add_summary(iu_summary_str, step)
                        # Save the model checkpoint after every validation.
                        checkpoint_path = os.path.join(train_dir, 'model.ckpt')
                        saver.save(sess, checkpoint_path, global_step=step)
                        print("Checkpoint created at %s" % checkpoint_path)
                        # Stop when fewer than 300 seconds of the run-time
                        # budget remain.
                        elapsed = time.time() - glob_start_time
                        if (elapsed + 300) > max_time_seconds:
                            print("Training stopped: max run time elapsed")
                            # The marker may not exist; a missing file must
                            # not abort the shutdown.
                            if os.path.exists("./RUNNING_SEGNET"):
                                os.remove("./RUNNING_SEGNET")
                            break
            finally:
                # Runs on normal completion, early stop and errors: flush
                # pending events and stop the input threads.
                summary_writer.close()
                coord.request_stop()
                coord.join(threads)