def main(args):
    """Reconstruct every ``*.mat`` light field in ``args.datafolder`` and score it.

    Builds the HDDRNet graph in "TEST" state, restores ``args.pretrained_model``
    and then, for each file: loads the ``"data"`` array, downsamples it,
    reconstructs it patch-wise, logs the mean aperture-wise PSNR / SSIM and
    saves the ground truth and the reconstruction under
    ``<args.result_folder>/HR`` and ``<args.result_folder>/SR``.

    Args:
        args: Namespace-like object. Attributes read here: ``select_gpu``,
            ``gamma_A``, ``gamma_S``, ``channels``, ``pretrained_model``,
            ``datafolder`` and ``result_folder``. The object is also forwarded
            to the model constructor and to ``ReconstructSpatialLFPatch``.

    Raises:
        SystemExit: with status 1 when restoring the checkpoint raises
            ``ValueError``.
    """
    # ============ Setting the GPU used for model training ============ #
    logging.info("===> Setting the GPUs: {}".format(args.select_gpu))
    os.environ["CUDA_VISIBLE_DEVICES"] = args.select_gpu

    # ===================== Definition of params ====================== #
    logging.info("===> Initialization")
    # gamma_A -> (angular size of the input, angular size of the ground truth)
    angular_sizes = {
        0: (3, 7),  # 3x3 -> 7x7
        2: (5, 9),  # 5x5 -> 9x9
        3: (3, 9),  # 3x3 -> 9x9
        4: (2, 8),  # 2x2 -> 8x8
    }
    if args.gamma_A in angular_sizes:
        in_views, gt_views = angular_sizes[args.gamma_A]
        inputs = tf.placeholder(
            tf.float32, [None, None, None, in_views, in_views, args.channels])
        groundtruth = tf.placeholder(
            tf.float32, [None, None, None, gt_views, gt_views, args.channels])
    else:
        # Unsupported factor: no placeholders are created, as in the original
        # fallback; the model constructor receives None for both.
        inputs = None
        groundtruth = None
    is_training = tf.placeholder(tf.bool, [])

    HDDRNet = import_model(args.gamma_S, args.gamma_A)
    model = HDDRNet(inputs, groundtruth, is_training, args, state="TEST")

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.InteractiveSession(config=config)

    init = tf.global_variables_initializer()
    sess.run(init)

    # ================= Restore the pre-trained model ================= #
    logging.info("===> Resuming the pre-trained model.")
    saver = tf.train.Saver()
    try:
        saver.restore(sess, args.pretrained_model)
    except ValueError:
        logging.error("Pretrained model: {} not found.".format(args.pretrained_model))
        sys.exit(1)

    # Sorted so that scenes are processed in a reproducible order.
    lflist = sorted(glob.glob(os.path.join(args.datafolder, '*.mat')))
    if not lflist:
        logging.warning("No .mat files found in {}".format(args.datafolder))
        return

    # The output folders do not depend on the scene, so create them once.
    for subdir in ("HR", "SR"):
        outdir = os.path.join(args.result_folder, subdir)
        if not os.path.exists(outdir):
            os.makedirs(outdir)

    for lfpath in lflist:
        # ===================== Read light field data ===================== #
        logging.info("===> Reading the light field data")
        LF = sio.loadmat(lfpath)["data"]
        # Move the first two axes behind the last two, then add a leading and
        # a trailing singleton axis.
        # NOTE(review): presumably this yields [batch, H, W, U, V, channel] to
        # match the placeholder layout - confirm against the data files.
        LF = LF.transpose(2, 3, 0, 1)
        LF = np.expand_dims(LF, axis=0)
        LF = np.expand_dims(LF, axis=-1)

        # ================== Downsample the light field =================== #
        logging.info("===> Downsampling")
        Groundtruth, low_LF = downsampling(LF, rs=args.gamma_S, ra=args.gamma_A, nSig=1.2)
        Groundtruth = shave_batch_LFs(Groundtruth, border=(28, 28))
        Groundtruth = Groundtruth.squeeze()
        low_inLF = low_LF.astype(np.float32) / 255.

        # ============= Reconstruct the original light field ============== #
        logging.info("===> Reconstructing ......")
        recons_LF = ReconstructSpatialLFPatch(low_inLF, model, inputs, is_training, sess, args, stride=60, border=(28, 28))
        recons_LF = recons_LF.squeeze()
        # Clip to [0, 1] before the uint8 cast: values outside that range would
        # otherwise wrap around and corrupt both the metrics and the saved file.
        recons_LF = np.uint8(np.clip(recons_LF, 0., 1.) * 255.)

        logging.info("===> Calculating the mean PSNR and SSIM values (on luminance channel)......")
        meanPSNR = np.mean(ApertureWisePSNR(Groundtruth, recons_LF))
        meanSSIM = np.mean(ApertureWiseSSIM(Groundtruth, recons_LF))

        # basename() handles the platform's path separator, unlike split('/').
        lfname = os.path.basename(lfpath)

        sio.savemat(os.path.join(args.result_folder, "HR", lfname), {"data": Groundtruth})
        sio.savemat(os.path.join(args.result_folder, "SR", lfname), {"data": recons_LF})

        logging.info("{0:+^74}".format(""))
        logging.info("|{0: ^72}|".format("Quantitative result for the scene: {}".format(lfname)))
        logging.info("|{0: ^72}|".format(""))
        logging.info("|{0: ^72}|".format("Method: HDDRNet |  Mean PSNR: {:.3f}      Mean SSIM: {:.3f}".format(meanPSNR, meanSSIM)))
        logging.info("{0:+^74}".format(""))
# Example no. 2
# 0
def main(args):
    """Train HDDRNet and evaluate it on the test split after every epoch.

    Builds the graph with fixed-size placeholders, optionally restores the
    VGG-19 variables (perceptual loss) and a previously saved epoch, then
    alternates one training pass and one evaluation pass per epoch. A
    checkpoint plus state info is written after every epoch, and additionally
    to ``BESTPSNR`` / ``BESTSSIM`` sub-folders whenever the test PSNR / SSIM
    improves.

    Args:
        args: Namespace-like object. Attributes read here: ``select_gpu``,
            ``batchSize``, ``imageSize``, ``viewSize``, ``gamma_S``,
            ``gamma_A``, ``channels``, ``lr_beta1``, ``lr_start``,
            ``perceptual_loss``, ``vgg_model``, ``datadir``, ``resume``,
            ``save_folder``, ``start_epoch`` and ``num_epoch``. The object is
            also forwarded to the model constructor. ``args.start_epoch`` is
            overwritten when ``args.resume`` is set.
    """
    # ============ Setting the GPU used for model training ============ #
    logging.info("===> Setting the GPUs: {}".format(args.select_gpu))
    os.environ["CUDA_VISIBLE_DEVICES"] = args.select_gpu

    # ===================== Definition of params ====================== #
    logging.info("===> Initialization")
    inputs = tf.placeholder(tf.float32, [
        args.batchSize, args.imageSize // args.gamma_S,
        args.imageSize // args.gamma_S, args.viewSize // args.gamma_A,
        args.viewSize // args.gamma_A, args.channels
    ])
    groundtruth = tf.placeholder(tf.float32, [
        args.batchSize, args.imageSize, args.imageSize, args.viewSize,
        args.viewSize, args.channels
    ])
    is_training = tf.placeholder(tf.bool, [])
    learning_rate = tf.placeholder(tf.float32, [])

    HDDRNet = import_model(args.gamma_S, args.gamma_A)
    model = HDDRNet(inputs, groundtruth, is_training, args)

    sess = tf.InteractiveSession(config=tf.ConfigProto(
        allow_soft_placement=True))
    opt = tf.train.AdamOptimizer(beta1=args.lr_beta1,
                                 learning_rate=learning_rate)
    train_op = opt.minimize(model.loss, var_list=model.net_variables)

    init = tf.global_variables_initializer()
    sess.run(init)

    # ============ Restore the VGG-19 network ============ #
    if args.perceptual_loss:
        logging.info("===> Restoring the VGG-19 Network for Perceptual Loss")
        vgg_var = [
            var_ for var_ in tf.global_variables() if "vgg19" in var_.name
        ]
        saver = tf.train.Saver(vgg_var)
        saver.restore(sess, args.vgg_model)

    # ============ Load the Train / Test Data ============ #
    logging.info("===> Loading the Training and Test Datasets")
    trainlist = glob.glob(os.path.join(args.datadir, "MSTrain/5x5/*/*.npy"))
    testlist = glob.glob(os.path.join(args.datadir, "MSTest/5x5/*.npy"))

    BESTPSNR = 0.0
    BESTSSIM = 0.0
    statetype = get_state(args.gamma_S, args.gamma_A)
    state_folder = os.path.join(args.save_folder, statetype)

    # =========== Restore the pre-trained model ========== #
    if args.resume:
        logging.info("Resuming the pre-trained model.")
        Epoch, BESTPSNR, BESTSSIM = read_stateinfo(state_folder)
        saver = tf.train.Saver()
        try:
            saver.restore(
                sess, os.path.join(state_folder,
                                   "epoch_{:03d}".format(Epoch)))
        except Exception as err:
            # Narrower than a bare ``except`` so KeyboardInterrupt and
            # SystemExit still propagate; the cause is logged for diagnosis.
            logging.info("No saved model found.")
            logging.info("Restore failed with: {}".format(err))
            args.start_epoch = 0
        else:
            args.start_epoch = Epoch + 1

    logging.info("===> Start Training")

    for epoch in range(args.start_epoch, args.num_epoch):
        random.shuffle(trainlist)

        num_iter = len(trainlist) // args.batchSize
        lr = adjust_learning_rate(args.lr_start, epoch, step=20)

        for ii in range(num_iter):
            y_batch = np.load(trainlist[ii])
            x_batch = downsampling(y_batch,
                                   K1=args.gamma_S,
                                   nSig=1.2,
                                   spatial_only=True)

            y_batch = y_batch.astype(np.float32) / 255.
            x_batch = x_batch.astype(np.float32) / 255.

            angular_loss = 0.0
            spatial_loss = 0.0
            total_loss = 0.0
            # Each entry of the loaded array is fed on its own, with a leading
            # axis of size one added.
            for x_item, y_item in zip(x_batch, y_batch):
                _, aloss, sloss, tloss = sess.run(
                    [
                        train_op, model.angular_loss, model.spatial_loss,
                        model.loss
                    ],
                    feed_dict={
                        inputs: np.expand_dims(x_item, axis=0),
                        groundtruth: np.expand_dims(y_item, axis=0),
                        is_training: True,
                        learning_rate: lr
                    })

                angular_loss += aloss
                spatial_loss += sloss
                total_loss += tloss

            angular_loss /= len(y_batch)
            spatial_loss /= len(y_batch)
            total_loss /= len(y_batch)

            logging.info(
                "Epoch {:03d} [{:03d}/{:03d}]   Angular loss: {:.6f} | Spatial loss: {:.6f} | "
                "Total loss: {:.6f} | Learning rate: {:.10f}".format(
                    epoch, ii, num_iter, angular_loss, spatial_loss,
                    total_loss, lr))

        # ===================== Testing ===================== #

        logging.info("===> Start Testing for Epoch {:03d}".format(epoch))
        num_testiter = len(testlist) // args.batchSize
        test_psnr = 0.0
        test_ssim = 0.0
        test_angularloss = []
        test_spatialloss = []
        test_totalloss = []

        for kk in range(num_testiter):
            y_batch = np.load(testlist[kk])
            x_batch = downsampling(y_batch,
                                   K1=args.gamma_S,
                                   nSig=1.2,
                                   spatial_only=True)

            y_batch = y_batch.astype(np.float32) / 255.
            x_batch = x_batch.astype(np.float32) / 255.

            angular_loss = 0.0
            spatial_loss = 0.0
            total_loss = 0.0
            recons_batch = []
            for x_item, y_item in zip(x_batch, y_batch):
                # Evaluation only: ``train_op`` is deliberately NOT fetched
                # here, otherwise the optimizer would update the weights on
                # the test data.
                aloss, sloss, tloss, recons = sess.run(
                    [
                        model.angular_loss, model.spatial_loss, model.loss,
                        model.Recons
                    ],
                    feed_dict={
                        inputs: np.expand_dims(x_item, axis=0),
                        groundtruth: np.expand_dims(y_item, axis=0),
                        is_training: False
                    })

                angular_loss += aloss
                spatial_loss += sloss
                total_loss += tloss
                recons_batch.append(recons)

            angular_loss /= len(y_batch)  # average value for a single LF image
            spatial_loss /= len(y_batch)  # average value for a single LF image
            total_loss /= len(y_batch)  # average value for a single LF image

            recons_batch = np.concatenate(recons_batch, axis=0)
            recons_batch[recons_batch > 1.] = 1.
            recons_batch[recons_batch < 0.] = 0.
            item_psnr = batchmeanpsnr(
                y_batch, recons_batch)  # average value for a single LF image
            item_ssim = batchmeanssim(
                y_batch, recons_batch)  # average value for a single LF image

            test_angularloss.append(angular_loss)
            test_spatialloss.append(spatial_loss)
            test_totalloss.append(total_loss)
            test_psnr += item_psnr
            test_ssim += item_ssim

        # Average over the files actually evaluated (num_testiter), not over
        # len(testlist); the guard avoids a division by zero when no test file
        # was evaluated.
        if num_testiter == 0:
            logging.warning("No test data evaluated for epoch {:03d}".format(epoch))
        num_evaluated = max(num_testiter, 1)
        test_psnr = test_psnr / num_evaluated
        test_ssim = test_ssim / num_evaluated
        avgtest_aloss = np.mean(test_angularloss) if test_angularloss else 0.0
        avgtest_sloss = np.mean(test_spatialloss) if test_spatialloss else 0.0
        avgtest_tloss = np.mean(test_totalloss) if test_totalloss else 0.0
        test_dict = {
            "epoch": epoch,
            "TestAvgPSNR": test_psnr,
            "TestAvgSSIM": test_ssim,
            "TestAvgAngularLoss": avgtest_aloss,
            "TestAvgSpatialLoss": avgtest_sloss,
            "TestAvgTotalLoss": avgtest_tloss,
            "BESTPSNR": BESTPSNR,
            "BESTSSIM": BESTSSIM
        }

        if test_psnr > BESTPSNR:
            savefolder = os.path.join(state_folder, "BESTPSNR")
            path = save_model(sess, savefolder, epoch)
            test_dict["BESTPSNR"] = test_psnr
            save_stateinfo(savefolder, test_dict)
            logging.info("Model saved to {}".format(path))
            logging.info("PSNR: {:.6f}(previous) update to {:.6f}(current) "
                         "[BEST PSNR weights saved]".format(
                             BESTPSNR, test_psnr))
            BESTPSNR = test_psnr

        if test_ssim > BESTSSIM:
            savefolder = os.path.join(state_folder, "BESTSSIM")
            path = save_model(sess, savefolder, epoch)
            test_dict["BESTSSIM"] = test_ssim
            save_stateinfo(savefolder, test_dict)
            logging.info("Model saved to {}".format(path))
            logging.info("SSIM: {:.6f}(previous) update to {:.6f}(current) "
                         "[BEST SSIM weights saved]".format(
                             BESTSSIM, test_ssim))
            BESTSSIM = test_ssim

        # =================== Save the epoch training info ===================== #
        path = save_model(sess, state_folder, epoch)
        save_stateinfo(state_folder, test_dict)
        logging.info("Model saved to {}".format(path))
def main(args):
    """Reconstruct the single light field at ``args.datapath`` and log PSNR/SSIM.

    Builds the HDDRNet graph in "TEST" state, restores ``args.pretrained_model``,
    loads the ``"data"`` array from the ``.mat`` file, downsamples it,
    reconstructs it with ``SpatialReconstruction`` and logs the mean
    aperture-wise PSNR / SSIM against the border-shaved ground truth.

    Args:
        args: Namespace-like object. Attributes read here: ``select_gpu``,
            ``batchSize``, ``imageSize``, ``viewSize``, ``gamma_S``,
            ``gamma_A``, ``channels``, ``pretrained_model`` and ``datapath``.
            The object is also forwarded to the model constructor and to
            ``SpatialReconstruction``.

    Raises:
        SystemExit: with status 1 when restoring the checkpoint raises
            ``ValueError``.
    """
    # ============ Setting the GPU used for model training ============ #
    logging.info("===> Setting the GPUs: {}".format(args.select_gpu))
    os.environ["CUDA_VISIBLE_DEVICES"] = args.select_gpu

    # ===================== Definition of params ====================== #
    logging.info("===> Initialization")
    low_size = args.imageSize // args.gamma_S
    low_views = args.viewSize // 2 + 1
    inputs = tf.placeholder(
        tf.float32,
        [args.batchSize, low_size, low_size, low_views, low_views, args.channels])
    groundtruth = tf.placeholder(
        tf.float32,
        [args.batchSize, args.imageSize, args.imageSize, args.viewSize,
         args.viewSize, args.channels])
    is_training = tf.placeholder(tf.bool, [])

    HDDRNet = import_model(args.gamma_S, args.gamma_A)
    model = HDDRNet(inputs, groundtruth, is_training, args, state="TEST")

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.InteractiveSession(config=config)

    init = tf.global_variables_initializer()
    sess.run(init)

    # ================= Restore the pre-trained model ================= #
    logging.info("===> Resuming the pre-trained model.")
    saver = tf.train.Saver()
    try:
        saver.restore(sess, args.pretrained_model)
    except ValueError:
        logging.error("Pretrained model: {} not found.".format(args.pretrained_model))
        sys.exit(1)

    # ===================== Read light field data ===================== #
    logging.info("===> Reading the light field data")
    LF = sio.loadmat(args.datapath)["data"]
    # Move the first two axes behind axes 2 and 3; the last axis stays last.
    # NOTE(review): presumably [U, V, H, W, C] -> [H, W, U, V, C] - confirm
    # against the data file.
    LF = LF.transpose(2, 3, 0, 1, 4)
    LF = shaveLF_by_factor(LF, args.gamma_S)
    LF = np.expand_dims(LF, axis=0)
    Groundtruth = shave_batch_LFs(LF, border=(3, 3))
    Groundtruth = Groundtruth.squeeze()

    # ================== Downsample the light field =================== #
    logging.info("===> Downsampling")
    _, low_LF = downsampling(LF, rs=args.gamma_S, ra=args.gamma_A, nSig=1.2)
    low_inLF = low_LF.astype(np.float32) / 255.

    # ============= Reconstruct the original light field ============== #
    logging.info("===> Reconstructing ......")
    recons_LF = SpatialReconstruction(low_inLF, model, inputs, is_training, sess, args, stride=60, border=(3, 3))
    recons_LF = recons_LF.squeeze()
    # Clip to [0, 1] before the uint8 cast: values outside that range would
    # otherwise wrap around and corrupt the PSNR / SSIM figures.
    recons_LF = np.uint8(np.clip(recons_LF, 0., 1.) * 255.)

    logging.info("===> Calculating the mean PSNR and SSIM values (on luminance channel)......")
    meanPSNR = np.mean(ApertureWisePSNR(Groundtruth, recons_LF))
    meanSSIM = np.mean(ApertureWiseSSIM(Groundtruth, recons_LF))

    # basename() handles the platform's path separator, unlike split('/').
    scene_name = os.path.basename(args.datapath)

    logging.info("{0:+^74}".format(""))
    logging.info("|{0: ^72}|".format("Quantitative result for the scene: {}".format(scene_name)))
    logging.info("|{0: ^72}|".format(""))
    logging.info("|{0: ^72}|".format("Method: HDDRNet |  Mean PSNR: {:.3f}      Mean SSIM: {:.3f}".format(meanPSNR, meanSSIM)))
    logging.info("{0:+^74}".format(""))