Exemplo n.º 1
0
def validation():
    """Evaluate the model on the validation iterator and return the mean PSNR.

    Re-initialises the validation iterator via ``validation_init_op``, runs
    ``val_num`` evaluation steps, and for every step:

    * computes ``calculate_psnr`` between the first element of the predicted
      transmission output and the first element of the fetched ``target``;
    * appends a line to ``./result/<model>/<epoch>/psnr_ssim.txt``;
    * saves a concatenated visualisation via ``utils.save_concat_img``.

    Relies on module-level state: ``sess``, ``model``, ``epoch``, ``val_num``
    and the graph tensors listed in ``fetch_list``.

    Returns:
        The ``np.mean`` of the per-step PSNR values.
    """
    img_names = sorted(glob("./data/real_world/val/others" + '/*'))
    psnr = []

    txt_path = "./result/%s/%04d/psnr_ssim.txt" % (model, epoch)

    # The fetch list does not change between steps, so build it once.
    # ``target`` is fetched twice on purpose: once for visualisation and once
    # as the PSNR reference (``tmp_T``).
    fetch_list = [
        transmission_layer, reflection_layer, input_ambient, target,
        reflection, input_pureflash, target, lossDict
    ]

    # ``with`` guarantees the log file is closed even if a step raises.
    with open(txt_path, 'w') as f:
        sess.run(validation_init_op)
        for step_idx in range(val_num):
            (pred_image_t, pred_image_r, gt_input_ambient, gt_target,
             gt_reflection, tmp_pureflash, tmp_T, _) = sess.run(fetch_list)
            print("Epc: %3d, shape of outputs: " % epoch, pred_image_t.shape,
                  pred_image_r.shape)
            tmp_psnr = calculate_psnr(pred_image_t[0], tmp_T[0])
            psnr.append(tmp_psnr)
            # NOTE(review): every row is labelled with img_names[0]; this looks
            # unintended, but the sample-to-filename mapping is not visible
            # here, so the original label is preserved -- confirm.
            f.write('%s: %.6f\n' % (img_names[0], tmp_psnr))
            utils.save_concat_img(
                gt_input_ambient, gt_target, gt_reflection, tmp_pureflash,
                pred_image_t, pred_image_r,
                "./result/%s/%04d/val_%06d.jpg" % (model, epoch, step_idx))
        mean_psnr = np.mean(psnr)
        f.write('%s: %.6f\n' % ("average", mean_psnr))
    return mean_psnr
Exemplo n.º 2
0
def validation():
    """Evaluate the model on the on-disk validation images; return mean PSNR.

    The sorted file list under ``./data/real_world/val/others`` is consumed in
    groups of five files per sample through ``load_paired_data``. Each sample
    is cropped so that its height and width are multiples of 32, fed through
    the graph, scored with ``calculate_psnr`` and visualised with
    ``utils.save_concat_img``. Per-sample scores and their average are written
    to ``./result/<model>/<epoch>/psnr_ssim.txt``.

    Relies on module-level state: ``sess``, ``model``, ``epoch`` and the graph
    tensors/placeholders referenced below.

    Returns:
        The ``np.mean`` of the per-sample PSNR values.
    """
    img_names = sorted(glob("./data/real_world/val/others" + '/*'))
    psnr = []

    txt_path = "./result/%s/%04d/psnr_ssim.txt" % (model, epoch)

    # The fetch list does not change between samples, so build it once.
    fetch_list = [
        transmission_layer, reflection_layer, input_ambient, target,
        reflection, lossDict
    ]

    # ``with`` guarantees the log file is closed even if a sample raises.
    with open(txt_path, 'w') as f:
        for sample_idx in range(len(img_names) // 5):
            tmp_pureflash, tmp_ambient, _, tmp_T, tmp_R = load_paired_data(
                img_names, sample_idx)

            # Crop height/width down to the nearest multiple of 32
            # (presumably required by the network's downsampling -- confirm).
            h, w = tmp_T.shape[1:3]
            h = h // 32 * 32
            w = w // 32 * 32
            tmp_T = tmp_T[:, :h, :w, :]
            tmp_R = tmp_R[:, :h, :w, :]
            tmp_ambient = tmp_ambient[:, :h, :w, :]
            tmp_pureflash = tmp_pureflash[:, :h, :w, :]

            (pred_image_t, pred_image_r, gt_input_ambient, gt_target,
             gt_reflection, _) = sess.run(
                 fetch_list,
                 feed_dict={
                     input_ambient: tmp_ambient,
                     reflection: tmp_R,
                     target: tmp_T,
                     input_pureflash: tmp_pureflash,
                 })
            print("Epc: %3d, shape of outputs: " % epoch, pred_image_t.shape,
                  pred_image_r.shape)
            tmp_psnr = calculate_psnr(pred_image_t[0], tmp_T[0])
            psnr.append(tmp_psnr)
            # NOTE(review): every row is labelled with img_names[0] rather
            # than a per-sample name; kept as in the original -- confirm.
            f.write('%s: %.6f\n' % (img_names[0], tmp_psnr))
            utils.save_concat_img(
                gt_input_ambient, gt_target, gt_reflection, tmp_pureflash,
                pred_image_t, pred_image_r,
                "./result/%s/%04d/val_%06d.jpg" % (model, epoch, sample_idx))
        mean_psnr = np.mean(psnr)
        f.write('%s: %.6f\n' % ("average", mean_psnr))
    return mean_psnr
def validation():
    """Evaluate on the whole validation iterator and return the mean PSNR.

    Re-initialises the validation iterator via ``validation_init_op`` and
    keeps running evaluation steps until ``sess.run`` raises
    ``tf.errors.OutOfRangeError``. Each step's PSNR (``calculate_psnr`` on the
    first element of the predicted output versus the fetched
    ``img_no_shadow``) is logged to ``./result/<model>/<epoch>/psnr_ssim.txt``;
    every 100th step additionally prints the output shapes and saves a
    concatenated visualisation via ``utils.save_concat_img``.

    Relies on module-level state: ``sess``, ``model``, ``epoch`` and the graph
    tensors listed in ``fetch_list``.

    Returns:
        The ``np.mean`` of the per-step PSNR values.
    """
    img_names = sorted(glob("./data/real_world/val/others" + '/*'))
    psnr = []

    txt_path = "./result/%s/%04d/psnr_ssim.txt" % (model, epoch)

    # The fetch list does not change between steps, so build it once.
    # ``img_no_shadow`` is fetched twice on purpose: once for visualisation
    # and once as the PSNR reference (``tmp_T``).
    fetch_list = [
        transmission_layer, shadow_mask_layer, img_with_shadow,
        img_no_shadow, shadow_mask, input_pureflash, img_no_shadow,
        lossDict
    ]

    # ``with`` guarantees the log file is closed even if a step raises.
    with open(txt_path, 'w') as f:
        sess.run(validation_init_op)
        step_idx = 0
        while True:
            # Only the fetch is expected to signal end-of-data, so keep the
            # try body limited to it.
            try:
                (pred_image_t, pred_image_r, gt_img_with_shadow,
                 gt_img_no_shadow, gt_shadow_mask, tmp_pureflash, tmp_T,
                 _) = sess.run(fetch_list)
            except tf.errors.OutOfRangeError:
                break

            tmp_psnr = calculate_psnr(pred_image_t[0], tmp_T[0])
            psnr.append(tmp_psnr)
            # NOTE(review): every row is labelled with img_names[0] rather
            # than a per-sample name; kept as in the original -- confirm.
            f.write('%s: %.6f\n' % (img_names[0], tmp_psnr))
            if step_idx % 100 == 0:
                print("Epc: %3d, shape of outputs: " % epoch,
                      pred_image_t.shape, pred_image_r.shape)
                utils.save_concat_img(
                    gt_img_with_shadow, gt_img_no_shadow, gt_shadow_mask,
                    tmp_pureflash, pred_image_t, pred_image_r,
                    "./result/%s/%04d/val_%06d.jpg" % (model, epoch, step_idx))
            step_idx += 1

        mean_psnr = np.mean(psnr)
        print('%s: %.6f\n' % ("average", mean_psnr))
        f.write('%s: %.6f\n' % ("average", mean_psnr))
    return mean_psnr
                target: tmp_T,
                input_pureflash: tmp_pureflash,
                input_flash: tmp_flash
            })
        step += 1
        if step % 10 == 0:
            crtLoss_str = "   ".join([
                "{}: {:.3f}".format(key, value)
                for key, value in crtDict.items()
            ])
            print("Epc:{:03d}-{:04d} | {} time:{:.3f}".format(
                epoch, id, crtLoss_str,
                time.time() - st))
            if step % 100 == 0:
                utils.save_concat_img(
                    gt_input_ambient, gt_target, gt_reflection, tmp_pureflash,
                    pred_image_t, pred_image_r,
                    "./result/%s/%04d/train_%06d.jpg" % (model, epoch, id))

    mean_psnr = validation()
    if mean_psnr > best_psnr:
        best_psnr = mean_psnr
        print("mean: {:.2f}".format(mean_psnr))
        print("best: {:.2f}".format(best_psnr))
        saver.save(sess, "./result/%s/model.ckpt" % model)
        saver.save(sess, "./result/%s/%04d/model.ckpt" % (model, epoch - 1))
    if (is_test or (epoch % save_model_freq == 0 and epoch < 1000)):
        saver.save(sess, "./result/%s/model.ckpt" % model)
        saver.save(sess, "./result/%s/%04d/model.ckpt" % (model, epoch - 1))

        img_names = sorted(
            glob("./data/synthetic/with_corrn_reflection/test/others" +
            _, pred_image_t, pred_image_r, gt_img_with_shadow, gt_img_no_shadow, gt_shadow_mask, tmp_pureflash, crtDict = sess.run(
                fetch_list)

            step += 1
            if step % 100 == 0:
                crtLoss_str = "   ".join([
                    "{}: {:.3f}".format(key, value)
                    for key, value in crtDict.items()
                ])
                print("Epc:{:03d}-{:04d} | {} time:{:.3f}".format(
                    epoch, id, crtLoss_str,
                    time.time() - st))
                st = time.time()
                if step % 100 == 0:
                    utils.save_concat_img(
                        gt_img_with_shadow, gt_img_no_shadow, gt_shadow_mask,
                        tmp_pureflash, pred_image_t, pred_image_r,
                        "./result/%s/%04d/train_%06d.jpg" % (model, epoch, id))
            id += 1
        except tf.errors.OutOfRangeError:
            break

    print("Epc:{:03d} |  time:{:.3f}".format(epoch, time.time() - ept))
    mean_psnr = validation()
    if mean_psnr > best_psnr:
        best_psnr = mean_psnr
        print("mean: {:.2f}".format(mean_psnr))
        print("best: {:.2f}".format(best_psnr))
        saver.save(sess, "./result/%s/model.ckpt" % model)
        saver.save(sess, "./result/%s/%04d/model.ckpt" % (model, epoch - 1))
    if (False):  #(is_test or (epoch % save_model_freq == 0 and epoch < 1000)):
        saver.save(sess, "./result/%s/model.ckpt" % model)