def main(args):
    """Evaluate a pre-trained HDDRNet on every ``*.mat`` light field in a folder.

    For each file found in ``args.datafolder`` the light field is loaded,
    downsampled with ``downsampling``, reconstructed patch-wise with
    ``ReconstructSpatialLFPatch`` and compared with the shaved ground truth.
    The ground truth and the reconstruction are saved under
    ``<args.result_folder>/HR`` and ``<args.result_folder>/SR`` using the
    original file name, and the mean PSNR / SSIM are logged.

    Args:
        args: Parsed options. This function reads ``select_gpu``,
            ``gamma_S``, ``gamma_A``, ``channels``, ``pretrained_model``,
            ``datafolder`` and ``result_folder``. ``args`` is also handed
            to the model constructor and the reconstruction helper, which
            may read further fields.

    The process exits with status 1 when ``args.gamma_A`` has no known
    angular configuration or when the checkpoint cannot be restored.
    """
    # ============ Setting the GPU used for model training ============ #
    logging.info("===> Setting the GPUs: {}".format(args.select_gpu))
    os.environ["CUDA_VISIBLE_DEVICES"] = args.select_gpu

    # ===================== Definition of params ====================== #
    logging.info("===> Initialization")
    # gamma_A -> (input angular size, ground-truth angular size)
    angular_sizes = {
        0: (3, 7),  # 3x3 -> 7x7
        2: (5, 9),  # 5x5 -> 9x9
        3: (3, 9),  # 3x3 -> 9x9
        4: (2, 8),  # 2x2 -> 8x8
    }
    if args.gamma_A not in angular_sizes:
        # Fail early with a clear message instead of building the model on
        # ``None`` placeholders and crashing somewhere less obvious.
        logging.error("Unsupported gamma_A: {} (expected one of {})".format(
            args.gamma_A, sorted(angular_sizes)))
        sys.exit(1)
    view_in, view_out = angular_sizes[args.gamma_A]
    inputs = tf.placeholder(
        tf.float32, [None, None, None, view_in, view_in, args.channels])
    groundtruth = tf.placeholder(
        tf.float32, [None, None, None, view_out, view_out, args.channels])
    is_training = tf.placeholder(tf.bool, [])

    HDDRNet = import_model(args.gamma_S, args.gamma_A)
    model = HDDRNet(inputs, groundtruth, is_training, args, state="TEST")

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.InteractiveSession(config=config)
    sess.run(tf.global_variables_initializer())

    # ================= Restore the pre-trained model ================= #
    logging.info("===> Resuming the pre-trained model.")
    saver = tf.train.Saver()
    try:
        saver.restore(sess, args.pretrained_model)
    except ValueError:
        logging.info("Pretrained model: {} not found.".format(args.pretrained_model))
        sys.exit(1)

    # Sorted so that scenes are processed in a reproducible order.
    lflist = sorted(glob.glob(os.path.join(args.datafolder, '*.mat')))
    if not lflist:
        logging.warning("No .mat files found in {}".format(args.datafolder))
        return

    # Create the output folders once instead of re-checking per scene.
    hr_folder = os.path.join(args.result_folder, "HR")
    sr_folder = os.path.join(args.result_folder, "SR")
    for folder in (hr_folder, sr_folder):
        if not os.path.exists(folder):
            os.makedirs(folder)

    for lfpath in lflist:
        # ===================== Read light field data ===================== #
        logging.info("===> Reading the light field data")
        LF = sio.loadmat(lfpath)["data"]
        # The stored array is treated as 4-D; its axis pairs are swapped and
        # a leading and a trailing singleton axis are added, giving a 6-D
        # array. NOTE(review): presumably (batch, H, W, U, V, channel) to
        # match the placeholders - confirm against the data files.
        LF = LF.transpose(2, 3, 0, 1)
        LF = np.expand_dims(LF, axis=0)
        LF = np.expand_dims(LF, axis=-1)

        # ================== Downsample the light field =================== #
        logging.info("===> Downsampling")
        Groundtruth, low_LF = downsampling(
            LF, rs=args.gamma_S, ra=args.gamma_A, nSig=1.2)
        # The same border is passed to the reconstruction call below.
        # NOTE(review): presumably so both arrays cover the same region - confirm.
        Groundtruth = shave_batch_LFs(Groundtruth, border=(28, 28))
        Groundtruth = Groundtruth.squeeze()
        low_inLF = low_LF.astype(np.float32) / 255.

        # ============= Reconstruct the original light field ============== #
        logging.info("===> Reconstructing ......")
        recons_LF = ReconstructSpatialLFPatch(
            low_inLF, model, inputs, is_training, sess, args,
            stride=60, border=(28, 28))
        recons_LF = recons_LF.squeeze()
        recons_LF = np.uint8(recons_LF * 255.)

        logging.info("===> Calculating the mean PSNR and SSIM values (on luminance channel)......")
        meanPSNR = np.mean(ApertureWisePSNR(Groundtruth, recons_LF))
        meanSSIM = np.mean(ApertureWiseSSIM(Groundtruth, recons_LF))

        # basename handles the platform's path separator, unlike split('/').
        lfname = os.path.basename(lfpath)
        sio.savemat(os.path.join(hr_folder, lfname), {"data": Groundtruth})
        sio.savemat(os.path.join(sr_folder, lfname), {"data": recons_LF})

        logging.info("{0:+^74}".format(""))
        logging.info("|{0: ^72}|".format("Quantitative result for the scene: {}".format(lfname)))
        logging.info("|{0: ^72}|".format(""))
        logging.info("|{0: ^72}|".format("Method: HDDRNet | Mean PSNR: {:.3f} Mean SSIM: {:.3f}".format(meanPSNR, meanSSIM)))
        logging.info("{0:+^74}".format(""))
def main(args):
    """Train HDDRNet and evaluate it on the test split after every epoch.

    Builds fixed-shape placeholders and the model, optionally restores the
    VGG-19 variables (``args.perceptual_loss``) and a previous training
    state (``args.resume``), then runs ``args.num_epoch`` epochs. Each
    ``.npy`` file is loaded as one group of light fields that are fed to
    the network one at a time. After each epoch the test files are
    evaluated, the best-PSNR / best-SSIM weights are saved in dedicated
    sub-folders, and the epoch checkpoint plus its state info are written
    to ``<args.save_folder>/<statetype>``.

    Args:
        args: Parsed options. This function reads ``select_gpu``,
            ``batchSize``, ``imageSize``, ``viewSize``, ``gamma_S``,
            ``gamma_A``, ``channels``, ``lr_beta1``, ``lr_start``,
            ``perceptual_loss``, ``vgg_model``, ``datadir``, ``resume``,
            ``save_folder``, ``start_epoch`` and ``num_epoch``.
            ``args.start_epoch`` is overwritten when resuming. ``args`` is
            also passed to the model constructor.
    """
    # ============ Setting the GPU used for model training ============ #
    logging.info("===> Setting the GPUs: {}".format(args.select_gpu))
    os.environ["CUDA_VISIBLE_DEVICES"] = args.select_gpu

    # ===================== Definition of params ====================== #
    logging.info("===> Initialization")
    low_size = args.imageSize // args.gamma_S
    low_views = args.viewSize // args.gamma_A
    inputs = tf.placeholder(tf.float32, [
        args.batchSize, low_size, low_size, low_views, low_views,
        args.channels
    ])
    groundtruth = tf.placeholder(tf.float32, [
        args.batchSize, args.imageSize, args.imageSize, args.viewSize,
        args.viewSize, args.channels
    ])
    is_training = tf.placeholder(tf.bool, [])
    learning_rate = tf.placeholder(tf.float32, [])

    HDDRNet = import_model(args.gamma_S, args.gamma_A)
    model = HDDRNet(inputs, groundtruth, is_training, args)

    sess = tf.InteractiveSession(config=tf.ConfigProto(
        allow_soft_placement=True))

    opt = tf.train.AdamOptimizer(beta1=args.lr_beta1,
                                 learning_rate=learning_rate)
    train_op = opt.minimize(model.loss, var_list=model.net_variables)

    sess.run(tf.global_variables_initializer())

    # ============ Restore the VGG-19 network ============ #
    if args.perceptual_loss:
        logging.info("===> Restoring the VGG-19 Network for Perceptual Loss")
        vgg_var = [v for v in tf.global_variables() if "vgg19" in v.name]
        saver = tf.train.Saver(vgg_var)
        saver.restore(sess, args.vgg_model)

    # ============ Load the Train / Test Data ============ #
    logging.info("===> Loading the Training and Test Datasets")
    trainlist = glob.glob(os.path.join(args.datadir, "MSTrain/5x5/*/*.npy"))
    testlist = glob.glob(os.path.join(args.datadir, "MSTest/5x5/*.npy"))

    best_psnr = 0.0
    best_ssim = 0.0

    statetype = get_state(args.gamma_S, args.gamma_A)
    state_folder = os.path.join(args.save_folder, statetype)

    # =========== Restore the pre-trained model ========== #
    if args.resume:
        logging.info("Resuming the pre-trained model.")
        last_epoch, best_psnr, best_ssim = read_stateinfo(state_folder)
        saver = tf.train.Saver()
        try:
            saver.restore(
                sess,
                os.path.join(state_folder, "epoch_{:03d}".format(last_epoch)))
            args.start_epoch = last_epoch + 1
        except Exception as err:
            # Catch Exception rather than a bare ``except`` so that
            # KeyboardInterrupt / SystemExit are not swallowed.
            logging.info("No saved model found.")
            logging.debug("Checkpoint restore failed: {}".format(err))
            args.start_epoch = 0

    logging.info("===> Start Training")
    for epoch in range(args.start_epoch, args.num_epoch):
        random.shuffle(trainlist)
        num_iter = len(trainlist) // args.batchSize
        lr = adjust_learning_rate(args.lr_start, epoch, step=20)

        for ii in range(num_iter):
            y_batch = np.load(trainlist[ii])
            x_batch = downsampling(y_batch,
                                   K1=args.gamma_S,
                                   nSig=1.2,
                                   spatial_only=True)
            y_batch = y_batch.astype(np.float32) / 255.
            x_batch = x_batch.astype(np.float32) / 255.

            num_samples = len(y_batch)
            angular_loss = 0.0
            spatial_loss = 0.0
            total_loss = 0.0
            for j in range(num_samples):
                # Light fields are fed one at a time with a leading axis of 1.
                x = np.expand_dims(x_batch[j], axis=0)
                y = np.expand_dims(y_batch[j], axis=0)
                # The reconstruction is not needed while training, so it is
                # not fetched.
                _, aloss, sloss, tloss = sess.run(
                    [
                        train_op, model.angular_loss, model.spatial_loss,
                        model.loss
                    ],
                    feed_dict={
                        inputs: x,
                        groundtruth: y,
                        is_training: True,
                        learning_rate: lr
                    })
                angular_loss += aloss
                spatial_loss += sloss
                total_loss += tloss

            angular_loss /= num_samples
            spatial_loss /= num_samples
            total_loss /= num_samples

            logging.info(
                "Epoch {:03d} [{:03d}/{:03d}] Angular loss: {:.6f} | Spatial loss: {:.6f} | "
                "Total loss: {:.6f} | Learning rate: {:.10f}".format(
                    epoch, ii, num_iter, angular_loss, spatial_loss,
                    total_loss, lr))

        # ===================== Testing ===================== #
        logging.info("===> Start Testing for Epoch {:03d}".format(epoch))
        num_testiter = len(testlist) // args.batchSize
        test_psnr = 0.0
        test_ssim = 0.0
        test_angularloss = []
        test_spatialloss = []
        test_totalloss = []
        for kk in range(num_testiter):
            y_batch = np.load(testlist[kk])
            x_batch = downsampling(y_batch,
                                   K1=args.gamma_S,
                                   nSig=1.2,
                                   spatial_only=True)
            y_batch = y_batch.astype(np.float32) / 255.
            x_batch = x_batch.astype(np.float32) / 255.

            num_samples = len(y_batch)
            angular_loss = 0.0
            spatial_loss = 0.0
            total_loss = 0.0
            recons_batch = []
            for k in range(num_samples):
                x = np.expand_dims(x_batch[k], axis=0)
                y = np.expand_dims(y_batch[k], axis=0)
                # Evaluation fetches only the losses and the reconstruction.
                # train_op is deliberately NOT run here, so test data never
                # updates the network weights.
                aloss, sloss, tloss, recons = sess.run(
                    [
                        model.angular_loss, model.spatial_loss, model.loss,
                        model.Recons
                    ],
                    feed_dict={
                        inputs: x,
                        groundtruth: y,
                        is_training: False
                    })
                angular_loss += aloss
                spatial_loss += sloss
                total_loss += tloss
                recons_batch.append(recons)

            # average values for a single LF image
            angular_loss /= num_samples
            spatial_loss /= num_samples
            total_loss /= num_samples

            recons_batch = np.concatenate(recons_batch, axis=0)
            # Clamp the reconstruction to the [0, 1] range before scoring.
            recons_batch = np.clip(recons_batch, 0., 1.)

            # average values for a single LF image
            item_psnr = batchmeanpsnr(y_batch, recons_batch)
            item_ssim = batchmeanssim(y_batch, recons_batch)

            test_angularloss.append(angular_loss)
            test_spatialloss.append(spatial_loss)
            test_totalloss.append(total_loss)
            test_psnr += item_psnr
            test_ssim += item_ssim

        # Average over the files actually evaluated; the guard avoids a
        # ZeroDivisionError when no test file is available.
        num_evaluated = max(num_testiter, 1)
        test_psnr = test_psnr / num_evaluated
        test_ssim = test_ssim / num_evaluated
        avgtest_aloss = np.mean(test_angularloss)
        avgtest_sloss = np.mean(test_spatialloss)
        avgtest_tloss = np.mean(test_totalloss)

        test_dict = {
            "epoch": epoch,
            "TestAvgPSNR": test_psnr,
            "TestAvgSSIM": test_ssim,
            "TestAvgAngularLoss": avgtest_aloss,
            "TestAvgSpatialLoss": avgtest_sloss,
            "TestAvgTotalLoss": avgtest_tloss,
            "BESTPSNR": best_psnr,
            "BESTSSIM": best_ssim
        }

        if test_psnr > best_psnr:
            savefolder = os.path.join(state_folder, "BESTPSNR")
            path = save_model(sess, savefolder, epoch)
            test_dict["BESTPSNR"] = test_psnr
            save_stateinfo(savefolder, test_dict)
            logging.info("Model saved to {}".format(path))
            logging.info("PSNR: {:.6f}(previous) update to {:.6f}(current) "
                         "[BEST PSNR weights saved]".format(
                             best_psnr, test_psnr))
            best_psnr = test_psnr

        if test_ssim > best_ssim:
            savefolder = os.path.join(state_folder, "BESTSSIM")
            path = save_model(sess, savefolder, epoch)
            test_dict["BESTSSIM"] = test_ssim
            save_stateinfo(savefolder, test_dict)
            logging.info("Model saved to {}".format(path))
            logging.info("SSIM: {:.6f}(previous) update to {:.6f}(current) "
                         "[BEST SSIM weights saved]".format(
                             best_ssim, test_ssim))
            best_ssim = test_ssim

        # =================== Save the epoch training info ===================== #
        path = save_model(sess, state_folder, epoch)
        save_stateinfo(state_folder, test_dict)
        logging.info("Model saved to {}".format(path))
def main(args):
    """Evaluate a pre-trained HDDRNet on the single light field at ``args.datapath``.

    The light field is loaded from the ``"data"`` entry of a ``.mat`` file,
    cropped with ``shaveLF_by_factor``, downsampled, reconstructed with
    ``SpatialReconstruction`` and compared with the border-shaved ground
    truth. The mean PSNR / SSIM are logged; nothing is written to disk.

    Args:
        args: Parsed options. This function reads ``select_gpu``,
            ``batchSize``, ``imageSize``, ``viewSize``, ``gamma_S``,
            ``gamma_A``, ``channels``, ``pretrained_model`` and
            ``datapath``. ``args`` is also handed to the model constructor
            and the reconstruction helper.

    The process exits with status 1 when the checkpoint cannot be restored
    (``ValueError`` from ``saver.restore``).
    """
    # ============ Setting the GPU used for model training ============ #
    logging.info("===> Setting the GPUs: {}".format(args.select_gpu))
    os.environ["CUDA_VISIBLE_DEVICES"] = args.select_gpu

    # ===================== Definition of params ====================== #
    logging.info("===> Initialization")
    low_size = args.imageSize // args.gamma_S
    # NOTE(review): the angular input size is viewSize // 2 + 1 regardless of
    # args.gamma_A - confirm this is intended for every supported setting.
    low_views = args.viewSize // 2 + 1
    inputs = tf.placeholder(
        tf.float32,
        [args.batchSize, low_size, low_size, low_views, low_views,
         args.channels])
    groundtruth = tf.placeholder(
        tf.float32,
        [args.batchSize, args.imageSize, args.imageSize, args.viewSize,
         args.viewSize, args.channels])
    is_training = tf.placeholder(tf.bool, [])

    HDDRNet = import_model(args.gamma_S, args.gamma_A)
    model = HDDRNet(inputs, groundtruth, is_training, args, state="TEST")

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.InteractiveSession(config=config)
    sess.run(tf.global_variables_initializer())

    # ================= Restore the pre-trained model ================= #
    logging.info("===> Resuming the pre-trained model.")
    saver = tf.train.Saver()
    try:
        saver.restore(sess, args.pretrained_model)
    except ValueError:
        logging.info("Pretrained model: {} not found.".format(args.pretrained_model))
        sys.exit(1)

    # ===================== Read light field data ===================== #
    logging.info("===> Reading the light field data")
    LF = sio.loadmat(args.datapath)["data"]
    # The stored array is treated as 5-D: the first two axis pairs are
    # swapped and the last axis stays in place.
    # NOTE(review): presumably (U, V, H, W, C) -> (H, W, U, V, C) - confirm.
    LF = LF.transpose(2, 3, 0, 1, 4)
    LF = shaveLF_by_factor(LF, args.gamma_S)
    LF = np.expand_dims(LF, axis=0)
    # The same border is passed to SpatialReconstruction below.
    # NOTE(review): presumably so both arrays cover the same region - confirm.
    Groundtruth = shave_batch_LFs(LF, border=(3, 3))
    Groundtruth = Groundtruth.squeeze()

    # ================== Downsample the light field =================== #
    logging.info("===> Downsampling")
    _, low_LF = downsampling(LF, rs=args.gamma_S, ra=args.gamma_A, nSig=1.2)
    low_inLF = low_LF.astype(np.float32) / 255.

    # ============= Reconstruct the original light field ============== #
    logging.info("===> Reconstructing ......")
    recons_LF = SpatialReconstruction(
        low_inLF, model, inputs, is_training, sess, args,
        stride=60, border=(3, 3))
    recons_LF = recons_LF.squeeze()
    recons_LF = np.uint8(recons_LF * 255.)

    logging.info("===> Calculating the mean PSNR and SSIM values (on luminance channel)......")
    meanPSNR = np.mean(ApertureWisePSNR(Groundtruth, recons_LF))
    meanSSIM = np.mean(ApertureWiseSSIM(Groundtruth, recons_LF))

    # basename handles the platform's path separator, unlike split('/').
    scene_name = os.path.basename(args.datapath)

    logging.info("{0:+^74}".format(""))
    logging.info("|{0: ^72}|".format("Quantitative result for the scene: {}".format(scene_name)))
    logging.info("|{0: ^72}|".format(""))
    logging.info("|{0: ^72}|".format("Method: HDDRNet | Mean PSNR: {:.3f} Mean SSIM: {:.3f}".format(meanPSNR, meanSSIM)))
    logging.info("{0:+^74}".format(""))