Ejemplo n.º 1
0
# Training setup for the Exposure Shifting Network (ESN): logging, input
# pipeline, forward graph, loss, optimizer, then the TensorFlow session.
import os  # needed to create the log directory before logging.basicConfig

logs_path = "./logs/"
train_prefix = True
train_u_net = True

# The text log goes to ./LOG/train.log. logging.basicConfig(filename=...) opens
# the file immediately and raises if the directory is missing, so create it.
os.makedirs('./LOG', exist_ok=True)
logging.basicConfig(filename='./LOG/train.log',
                    format='[%(asctime)s-%(filename)s-%(levelname)s:%(message)s]',
                    level=logging.DEBUG,
                    filemode='a',
                    # Space between date and hour; the old '%d%I' ran them together.
                    datefmt='%Y-%m-%d %I:%M:%S %p')
logging.info("init")

# Graph-level random seed. TensorFlow reads it when each random op is created,
# so it must be set BEFORE the input pipeline and network below are built.
# Setting it after graph construction left those ops unaffected by it.
tf.set_random_seed(2020)

# Step one: dataset reconstruction.
logging.info("step one dataset reconstruction")
data_it = ic.get_dataset_iterator()
(img_input_Graph, Y_im_gt_Graph, inp_iso_Graph, inp_t_Graph,
 amp_ratio_Graph, gt_t_Graph, wb_graph) = data_it.get_next()

# Step two: network reconstruction.
logging.info("step two network reconstruction")

# Exposure Shifting Network (ESN). Only the first output of quality_pri_net is
# used here; the second one is discarded.
Yhat_im_output_Graph, _ = net.quality_pri_net(img_input_Graph, wb_graph,
                                              inp_iso_Graph, inp_t_Graph,
                                              gt_t_Graph)

# Loss: 0.85 * mean absolute error + 0.15 * the loss returned by cal_ssim_loss.
loss_MAE_Graph = tf.reduce_mean(tf.abs(Yhat_im_output_Graph - Y_im_gt_Graph))
loss_SSIM_Graph = cal_ssim_loss(Yhat_im_output_Graph, Y_im_gt_Graph)
loss_Graph = 0.85 * loss_MAE_Graph + 0.15 * loss_SSIM_Graph

# Learning rate held in a non-trainable variable (presumably so it can be
# reassigned during training without rebuilding the optimizer — confirm).
learning_rate_Graph = tf.Variable(initial_value=1e-4, trainable=False, name='lrG')

# Optimize every trainable variable. To train only the prefix sub-network,
# use: [var for var in tf.trainable_variables() if 'prefix_net' in var.name]
var_list_ = tf.trainable_variables()
opt_Graph = tf.train.AdamOptimizer(learning_rate=learning_rate_Graph).minimize(
    loss_Graph, var_list=var_list_)

#################################################################################################################
sess = tf.Session()
# sess = tf_debug.LocalCLIDebugWrapperSession(sess)
Ejemplo n.º 2
0
# Inference setup: build the prediction graph from placeholders and restore
# the trained weights from the latest checkpoint in `checkpointpath`.
dng_image_files = path_name

# TensorFlow network input portal.
# Notice: vars with '_Graph' are nodes in the TensorFlow network, otherwise they are not.
img_input_Graph = tf.placeholder(tf.float32, [None, None, None, 4])  # rank-4, last dim 4
wb_graph = tf.placeholder(tf.float32, None)
inp_iso_Graph = tf.placeholder(tf.float32, None)
inp_t_Graph = tf.placeholder(tf.float32, None)

# Construct the network frame: the output of brightness_predict_net is fed to
# quality_pri_net as its last argument. Only the first output is kept.
gt_t_prediction_Graph = net.brightness_predict_net(img_input_Graph, wb_graph,
                                                   inp_iso_Graph, inp_t_Graph)
Yhat_im_output_Graph, _ = net.quality_pri_net(img_input_Graph,
                                              wb_graph,
                                              inp_iso_Graph,
                                              inp_t_Graph,
                                              gt_t_prediction_Graph)

# Locate the checkpoint before opening the session. get_checkpoint_state
# returns None when no checkpoint state file is found. Fail with a clear
# message instead of an AttributeError on None.
ckpt = tf.train.get_checkpoint_state(checkpointpath)
if ckpt is None or not ckpt.model_checkpoint_path:
    raise FileNotFoundError(
        'No TensorFlow checkpoint found in {!r}'.format(checkpointpath))

# TensorFlow session and parameter loading.
sess = tf.Session()
variables_to_restore = tf.contrib.framework.get_variables_to_restore()
saver_temp = tf.train.Saver(variables_to_restore)
saver_temp.restore(sess, ckpt.model_checkpoint_path)
print('ok')


def get_result(img_path, classID, time, iso):
    in_path = img_path
    ref_path = '00000000000'  # taking the place of ground truth path, ignore it