Example #1
def evaluate():
    # Lists used to collect per-image PSNR and SSIM values.
    psnr_list = []
    ssim_list = []
    # Read all hazed image indexes and clear images from the test directories
    if not df.FLAGS.eval_only_haze:
        di.image_input(df.FLAGS.clear_test_images_dir, _clear_test_file_names, _clear_test_img_list,
                       _clear_test_directory, clear_image=True)
        if len(_clear_test_img_list) == 0:
            raise RuntimeError("No image found! Please supply clear images for training or eval ")
    # Hazed test image pre-processing
    di.image_input(df.FLAGS.haze_test_images_dir, _hazed_test_file_names, _hazed_test_img_list,
                   clear_dict=None, clear_image=False)
    if len(_hazed_test_img_list) == 0:
        raise RuntimeError("No image found! Please supply hazed images for training or eval ")

    for image in _hazed_test_img_list:
        graph = tf.Graph()
        with graph.as_default():
            # ########################################################################
            # ########################Load images from disk##############################
            # ########################################################################
            # Read the hazed image from disk and convert it to a normalized RGB float array
            hazed_image = im.open(image.path)
            hazed_image = hazed_image.convert('RGB')
            shape = np.shape(hazed_image)
            hazed_image_placeholder = tf.placeholder(tf.float32, shape=[constant.SINGLE_IMAGE_NUMBER, shape[0], shape[1], constant.RGB_CHANNEL])
            hazed_image_arr = np.array(hazed_image)
            float_hazed_image = hazed_image_arr.astype('float32') / 255
            if not df.FLAGS.eval_only_haze:
                clear_image = di.find_corres_clear_image(image, _clear_test_directory)
                clear_image_arr = np.array(clear_image)

            # ########################################################################
            # ###################Restore model and do evaluations#####################
            # ########################################################################
            gman = model.GMAN_V1()
            logits = gman.inference(hazed_image_placeholder, batch_size=1, h=shape[0], w=shape[1])
            variable_averages = tf.train.ExponentialMovingAverage(
                constant.MOVING_AVERAGE_DECAY)
            variables_to_restore = variable_averages.variables_to_restore()
            saver = tf.train.Saver(variables_to_restore)
            # Run a single-image evaluation, with or without a matching clear image
            if not df.FLAGS.eval_only_haze:
                eval_once(graph, saver, logits, float_hazed_image, clear_image_arr, image, hazed_image_placeholder,
                          psnr_list, ssim_list, shape[0], shape[1])
            else:
                eval_once(graph, saver, logits, float_hazed_image, None, image, hazed_image_placeholder,
                          psnr_list, ssim_list, shape[0], shape[1])

    if not df.FLAGS.eval_only_haze:
        psnr_avg = cal_average(psnr_list)
        format_str = 'Average PSNR: %5f'
        logger.info(format_str % psnr_avg)
        ssim_avg = cal_average(ssim_list)
        format_str = 'Average SSIM: %5f'
        logger.info(format_str % ssim_avg)
def input_create_flow_control_json():
    # The flow-control JSON file does not exist yet; create it with an empty task list
    try:
        with open(df.FLAGS.train_json_path, "w") as flow_control_file:
            flow_control = {'train_flow_control': []}
            json.dump(flow_control, flow_control_file)
        logger.info("Created JSON file for training flow control.")
    except IOError:
        raise RuntimeError("[Error]: Failed to read/write " + df.FLAGS.train_json_path + ".")
    return flow_control["train_flow_control"]
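
A matching loader is not shown in this example. A minimal sketch, assuming the file layout written above; the name input_load_flow_control_json is hypothetical:

def input_load_flow_control_json():
    # Hypothetical counterpart to input_create_flow_control_json(): read the
    # flow-control list back from the JSON file created above.
    with open(df.FLAGS.train_json_path, "r") as flow_control_file:
        flow_control = json.load(flow_control_file)
    return flow_control["train_flow_control"]
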
    def save(self, current_learning_rate):
        current_learning_rate = str(current_learning_rate)
        if not os.path.exists(self.path):
            logger.info("Creating JSON file for learning rate.")
        try:
            with open(self.path, "w") as learning_rate_file:
                learning_rate = {'learning_rate': current_learning_rate}
                json.dump(learning_rate, learning_rate_file)
        except IOError:
            raise RuntimeError("[Error]: Failed to read/write " + self.path + ".")
        return float(learning_rate["learning_rate"])
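
train() below calls initial_learning_rate.load() and reads decay_factor, neither of which appears in this example. A minimal sketch of how load() might mirror save(); the initial_learning_rate attribute is an assumption:

    def load(self):
        # Assumed behaviour: return the persisted learning rate if the JSON file
        # exists, otherwise fall back to the initial value passed to __init__.
        if not os.path.exists(self.path):
            return self.initial_learning_rate
        with open(self.path, "r") as learning_rate_file:
            learning_rate = json.load(learning_rate_file)
        return float(learning_rate["learning_rate"])
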
Example #4
def eval_once(graph, saver, train_op, hazed_image, clear_image, hazed_images_obj, placeholder, psnr_list, ssim_list, h, w):
    with tf.Session(graph=graph, config=tf.ConfigProto(
            allow_soft_placement=True,
            log_device_placement=df.FLAGS.log_device_placement,
            gpu_options=tf.GPUOptions(allow_growth=True,
                                      per_process_gpu_memory_fraction=1,
                                      visible_device_list="0"))) as sess:
        ckpt = tf.train.get_checkpoint_state(df.FLAGS.checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            # Restores from checkpoint
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            logger.info('No checkpoint file found')
            return
        start = time.time()
        prediction = sess.run([train_op], feed_dict={placeholder: [hazed_image]})
        duration = time.time() - start
        # Convert the prediction back to an image and write the dehazed result to disk
        dehazed_image = write_images_to_file(prediction, hazed_images_obj, h, w, sess)
        if not df.FLAGS.eval_only_haze:
            psnr_value = cal_psnr(dehazed_image, np.uint8(clear_image))
            ssim_value = measure.compare_ssim(np.uint8(dehazed_image), np.uint8(clear_image), multichannel=True)
            ssim_list.append(ssim_value)
            psnr_list.append(psnr_value)
            logger.info('-------------------------------------------------------------------------------------------------------------------------------')
            format_str = 'image: %s PSNR: %f; SSIM: %f; (%.4f seconds)'
            logger.info(format_str % (hazed_images_obj.path, psnr_value, ssim_value, duration))
            logger.info('-------------------------------------------------------------------------------------------------------------------------------')
        else:
            logger.info('-------------------------------------------------------------------------------------------------------------------------------')
            format_str = 'image: %s (%.4f seconds)'
            logger.info(format_str % (hazed_images_obj.path, duration))
            logger.info('-------------------------------------------------------------------------------------------------------------------------------')
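
The cal_psnr helper used above is defined elsewhere in the project. A minimal sketch of a standard PSNR computation for 8-bit images (the project's own helper may differ):

def cal_psnr(im1, im2):
    # Assumed helper: peak signal-to-noise ratio between two 8-bit images.
    diff = np.asarray(im1, dtype=np.float64) - np.asarray(im2, dtype=np.float64)
    mse = np.mean(diff ** 2)
    if mse == 0:
        return float('inf')
    return 10.0 * np.log10((255.0 ** 2) / mse)
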
def input_create_tfrecord_json():
    tfrecord_list = os.listdir(df.FLAGS.tfrecord_path)
    tfrecord_status_file = open(df.FLAGS.tfrecord_json, "w")
    try:
        # Create a dictionary of tf-record names:
        # key (str): tf-record file name; value (bool): whether the record is complete
        tfrecord_existing_dict = {"tfrecord_status": {}}
        # {filename-0.tfrecords : False ...  filename-max_epoch-1.tfrecords : False}
        for index in range(500):
            tfrecord_name = df.FLAGS.tfrecord_format % index
            if tfrecord_name in tfrecord_list:
                tfrecord_existing_dict["tfrecord_status"][tfrecord_name] = constant.INPUT_TFRECORD_COMPLETE
            else:
                tfrecord_existing_dict["tfrecord_status"][tfrecord_name] = constant.INPUT_TFRECORD_NOT_COMPLETE
        json.dump(tfrecord_existing_dict, tfrecord_status_file)
        logger.info("Create Json file for record tf-record.")
    except IOError as err:
        raise RuntimeError("[Error]: Error happens when read/write " + df.FLAGS.tfrecord_json + ".")
    finally:
        tfrecord_status_file.close()
    return tfrecord_existing_dict
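
The status dictionary returned above presumably gets updated as tf-records are produced. A minimal sketch of such an update; the name input_mark_tfrecord_complete is hypothetical:

def input_mark_tfrecord_complete(tfrecord_existing_dict, tfrecord_name):
    # Hypothetical helper: mark one tf-record as complete and persist the
    # updated status back to the JSON file.
    tfrecord_existing_dict["tfrecord_status"][tfrecord_name] = constant.INPUT_TFRECORD_COMPLETE
    with open(df.FLAGS.tfrecord_json, "w") as tfrecord_status_file:
        json.dump(tfrecord_existing_dict, tfrecord_status_file)
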
def train(tf_record_path, image_number, config):
    logger.info("Training on: %s" % tf_record_path)
    tf.reset_default_graph()
    with tf.Graph().as_default():
        # Create a variable to count the number of train() calls. This equals the
        # number of batches processed * FLAGS.num_gpus.
        global_step = tf.get_variable('global_step', [], initializer=tf.constant_initializer(0), trainable=False)
        # Calculate the learning rate schedule.
        if constant.NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN < df.FLAGS.batch_size:
            raise RuntimeError('NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN cannot be smaller than batch_size!')
        num_batches_per_epoch = (constant.NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN /
                                 df.FLAGS.batch_size)
        decay_steps = int(num_batches_per_epoch * constant.NUM_EPOCHS_PER_DECAY)

        initial_learning_rate = learning_rate.LearningRate(constant.INITIAL_LEARNING_RATE, df.FLAGS.train_learning_rate)
        lr = tf.train.exponential_decay(initial_learning_rate.load(),
                                        global_step,
                                        decay_steps,
                                        initial_learning_rate.decay_factor,
                                        staircase=True)

        # Create an optimizer that performs gradient descent.
        opt = tf.train.AdamOptimizer(lr)
        # opt = tf.train.GradientDescentOptimizer(lr)

        batch_queue = di.input_get_queue_from_tfrecord(tf_record_path, df.FLAGS.batch_size,
                                                       df.FLAGS.input_image_height, df.FLAGS.input_image_width)
        # Calculate the gradients for each model tower.
        # vgg_per = Vgg16()
        tower_grads = []
        with tf.variable_scope(tf.get_variable_scope()):
            gman_model = model.GMAN_V1()
            gman_net = net.Net(gman_model)
            for i in range(df.FLAGS.num_gpus):
                with tf.device('/gpu:%d' % i):
                    with tf.name_scope('%s_%d' % (constant.TOWER_NAME, i)) as scope:
                        gman_tower = tower.GMEAN_Tower(gman_net, batch_queue, scope, tower_grads, opt)
                        summaries, loss = gman_tower.process()

        # We must calculate the mean of each gradient. Note that this is the
        # synchronization point across all towers.
        grads = tower.Tower.average_gradients(tower_grads)
        # Add a summary to track the learning rate.
        summaries.append(tf.summary.scalar('learning_rate', lr))

        # Apply the gradients to adjust the shared variables.
        apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)

        # Add histograms for gradients.
        for grad, var in grads:
            if grad is not None:
                summaries.append(tf.summary.histogram(var.op.name + '/gradients', grad))

        # Track the moving averages of all trainable variables.
        variable_averages = tf.train.ExponentialMovingAverage(constant.MOVING_AVERAGE_DECAY, global_step)
        variables_averages_op = variable_averages.apply(tf.trainable_variables())

        # Group all updates into a single train op.
        train_op = tf.group(apply_gradient_op, variables_averages_op)

        # Create a saver.
        saver = tf.train.Saver(tf.global_variables())

        # Build the summary operation from the last tower summaries.
        summary_op = tf.summary.merge(summaries)

        # Build an initialization operation to run below.
        init = tf.global_variables_initializer()

        # Start running operations on the Graph. allow_soft_placement must be set to
        # True to build towers on GPU, as some of the ops do not have GPU
        # implementations.
        sess = tf.Session(config=tf.ConfigProto(
            allow_soft_placement=True,
            log_device_placement=df.FLAGS.log_device_placement,
            gpu_options=tf.GPUOptions(allow_growth=constant.TRAIN_GPU_MEMORY_ALLOW_GROWTH,
                                      per_process_gpu_memory_fraction=constant.TRAIN_GPU_MEMORY_FRACTION,
                                      visible_device_list=constant.TRAIN_VISIBLE_GPU_LIST))
        )

        # Restore previous trained model
        if config[dc.CONFIG_TRAINING_TRAIN_RESTORE]:
            train_load_previous_model(df.FLAGS.train_dir, saver, sess)
        else:
            sess.run(init)

        coord = tf.train.Coordinator()
        # Start the queue runners.
        queue_runners = tf.train.start_queue_runners(sess=sess, coord=coord, daemon=False)

        summary_writer = tf.summary.FileWriter(df.FLAGS.train_dir, sess.graph)
        max_step = int((image_number / df.FLAGS.batch_size) * 2)
        # Each tf-record is iterated over twice, hence the factor of 2 above.
        for step in range(max_step):
            start_time = time.time()
            # The learning rate is fetched only on steps where a checkpoint (and the
            # current learning rate) will also be saved below.
            if step != 0 and (step % 1000 == 0 or (step + 1) == max_step):
                _, loss_value, current_learning_rate = sess.run([train_op, loss, lr])
            else:
                _, loss_value = sess.run([train_op, loss])
            duration = time.time() - start_time

            assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

            if step % 10 == 0:
                num_examples_per_step = df.FLAGS.batch_size * df.FLAGS.num_gpus
                examples_per_sec = num_examples_per_step / duration
                sec_per_batch = duration / df.FLAGS.num_gpus

                format_str = ('%s: step %d, loss = %.8f (%.1f examples/sec; %.3f '
                              'sec/batch)')
                print(format_str % (datetime.now(), step, loss_value,
                                    examples_per_sec, sec_per_batch))

            if step % 1000 == 0:
                summary_str = sess.run(summary_op)
                summary_writer.add_summary(summary_str, step)

            # Save the model checkpoint periodically.
            if step != 0 and (step % 1000 == 0 or (step + 1) == max_step):
                checkpoint_path = os.path.join(df.FLAGS.train_dir, 'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step)
                initial_learning_rate.save(current_learning_rate)

        coord.request_stop()
        sess.close()
        coord.join(queue_runners, stop_grace_period_secs=constant.TRAIN_STOP_GRACE_PERIOD, ignore_live_threads=True)
    logger.info("=========================================================================================")