# --- Training configuration --------------------------------------------------
logs_path = "./logs/"
train_prefix = True
train_u_net = True

logging.basicConfig(
    filename='./LOG/train.log',
    format='[%(asctime)s-%(filename)s-%(levelname)s:%(message)s]',
    level=logging.DEBUG,
    filemode='a',
    # A space separates date and time; without it the day and hour were
    # glued together in the log (e.g. "2020-01-0103:04:05 PM").
    datefmt='%Y-%m-%d %I:%M:%S %p',
)
logging.info("init")

# Set the graph-level seed BEFORE any op is created. TF1 random ops read the
# graph seed at construction time, so calling this after the network was
# built (as before) did not seed the ops already in the graph.
tf.set_random_seed(2020)

# Step one: dataset reconstruction.
logging.info("step one dataset reconstruction")
data_it = ic.get_dataset_iterator()
(img_input_Graph, Y_im_gt_Graph, inp_iso_Graph, inp_t_Graph,
 amp_ratio_Graph, gt_t_Graph, wb_graph) = data_it.get_next()

# Step two: network reconstruction.
logging.info("step two network reconstruction")
# Exposure Shifting Network (ESN); the second return value is unused here.
Yhat_im_output_Graph, _ = net.quality_pri_net(
    img_input_Graph, wb_graph, inp_iso_Graph, inp_t_Graph, gt_t_Graph)

# Training loss: weighted blend of mean absolute error and SSIM loss.
loss_MAE_Graph = tf.reduce_mean(tf.abs(Yhat_im_output_Graph - Y_im_gt_Graph))
loss_SSIM_Graph = cal_ssim_loss(Yhat_im_output_Graph, Y_im_gt_Graph)
loss_Graph = 0.85 * loss_MAE_Graph + 0.15 * loss_SSIM_Graph

# Learning rate lives in a non-trainable variable so it can be changed
# from the session without rebuilding the optimizer.
learning_rate_Graph = tf.Variable(initial_value=1e-4, trainable=False, name='lrG')
# Optimise every trainable variable in the graph (a narrower alternative is to
# keep only variables whose name contains 'prefix_net').
var_list_ = tf.trainable_variables()
opt_Graph = tf.train.AdamOptimizer(learning_rate=learning_rate_Graph).minimize(
    loss_Graph, var_list=var_list_)

sess = tf.Session()
# sess = tf_debug.LocalCLIDebugWrapperSession(sess)
dng_image_files = path_name

# TensorFlow network input portal.
# Notice: names ending in '_Graph' are nodes in the TensorFlow graph;
# other names are ordinary Python values.
img_input_Graph = tf.placeholder(tf.float32, [None, None, None, 4])
wb_graph = tf.placeholder(tf.float32, None)
inp_iso_Graph = tf.placeholder(tf.float32, None)
inp_t_Graph = tf.placeholder(tf.float32, None)

# Construct the network: the brightness-prediction output is fed to
# quality_pri_net in place of a ground-truth exposure value.
gt_t_prediction_Graph = net.brightness_predict_net(
    img_input_Graph, wb_graph, inp_iso_Graph, inp_t_Graph)
Yhat_im_output_Graph, _ = net.quality_pri_net(
    img_input_Graph,
    wb_graph,
    inp_iso_Graph,
    inp_t_Graph,
    gt_t_prediction_Graph,
)

# TensorFlow session and checkpoint/parameter loading.
sess = tf.Session()
ckpt = tf.train.get_checkpoint_state(checkpointpath)
# get_checkpoint_state returns None when the directory holds no checkpoint;
# fail with a clear message instead of an AttributeError on None below.
if ckpt is None or not ckpt.model_checkpoint_path:
    raise FileNotFoundError('No checkpoint found in %r' % (checkpointpath,))
variables_to_restore = tf.contrib.framework.get_variables_to_restore()
saver_temp = tf.train.Saver(variables_to_restore)
saver_temp.restore(sess, ckpt.model_checkpoint_path)
print('ok')


def get_result(img_path, classID, time, iso):
    in_path = img_path
    ref_path = '00000000000'  # taking the place of ground truth path, ignore it