def infer(data_path, model):
    psnr = utils.AvgrageMeter()
    ssim = utils.AvgrageMeter()

    model.eval()
    transforms = torchvision.transforms.Compose(
        [torchvision.transforms.ToTensor()])

    with torch.no_grad():
        for step, pt in enumerate(glob.glob(data_path)):
            image = np.array(Image.open(pt))

            clear_image = utils.crop_img(image[:, :image.shape[1] // 2, :],
                                         base=args.patch_size)
            rain_image = utils.crop_img(image[:, image.shape[1] // 2:, :],
                                        base=args.patch_size)

            # # Test on whole image
            # input = transforms(rain_image).unsqueeze(dim=0).cuda()
            # target = transforms(clear_image).unsqueeze(dim=0).cuda(non_blocking=True)
            # logits = model(input)
            # n = input.size(0)

            # Test on whole image with data augmentation
            target = transforms(clear_image).unsqueeze(dim=0).cuda()
            for i in range(8):
                im = utils.data_augmentation(rain_image, i)
                input = transforms(im.copy()).unsqueeze(dim=0).cuda()
                begin_time = time.time()
                if i == 0:
                    logits = utils.inverse_augmentation(
                        model(input).cpu().numpy().transpose(0, 2, 3, 1)[0], i)
                else:
                    logits = logits + utils.inverse_augmentation(
                        model(input).cpu().numpy().transpose(0, 2, 3, 1)[0], i)
                end_time = time.time()
            n = input.size(0)
            logits = transforms(logits / 8).unsqueeze(dim=0).cuda()

            # # Test on patches2patches
            # noise_patches = utils.slice_image2patches(rain_image, patch_size=args.patch_size)
            # image_patches = utils.slice_image2patches(clear_image, patch_size=args.patch_size)
            # input = torch.tensor(noise_patches.transpose(0,3,1,2)/255.0, dtype=torch.float32).cuda()
            # target = torch.tensor(image_patches.transpose(0,3,1,2)/255.0, dtype=torch.float32).cuda()
            # logits = model(input)
            # n = input.size(0)

            s = pytorch_ssim.ssim(torch.clamp(logits, 0, 1), target)
            p = utils.compute_psnr(
                np.clip(logits.detach().cpu().numpy(), 0, 1),
                target.detach().cpu().numpy())
            psnr.update(p, n)
            ssim.update(s, n)
            print('psnr:%6f ssim:%6f' % (p, s))

            # Image.fromarray(rain_image).save(args.save+'/'+str(step)+'_noise.png')
            # Image.fromarray(np.clip(logits[0].cpu().numpy().transpose(1,2,0)*255, 0, 255).astype(np.uint8)).save(args.save+'/'+str(step)+'_denoised.png')

    return psnr.avg, ssim.avg
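
All of these snippets delegate the metric itself to a repo-specific utils.compute_psnr, and the signatures vary: some take arrays already scaled to [0, 1], others pass an explicit peak value as a third argument. For reference, a minimal self-contained sketch of such a helper (hypothetical name, so it is not confused with any repository's own implementation):

import numpy as np

def compute_psnr_sketch(pred, target, peak=1.0):
    # Hypothetical sketch: mean squared error over all pixels, then PSNR in dB.
    pred = np.asarray(pred, dtype=np.float64)
    target = np.asarray(target, dtype=np.float64)
    mse = np.mean((pred - target) ** 2)
    if mse == 0:
        return float('inf')
    return 10.0 * np.log10((peak ** 2) / mse)
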
Example 2
def evaluate_prediction():
    groundtruthfile = '/nrs/saalfeld/heinrichl/SR-data/FIBSEM/downscaled/bigh5-16isozyx/validation.h5'
    predictionfile = '/nrs/saalfeld/heinrichl/results_keras/Unet3-32-2_wogt_10cubic/finetuning_avg10weights_lrs1' \
                     '/validation30.h5'
    gt_arr = h5py.File(groundtruthfile, 'r')['raw']
    pred_arr = h5py.File(predictionfile, 'r')['raw']
    gt_arr = utils.cut_to_sc(gt_arr, 4, 0)
    pred_arr = utils.cut_to_sc(pred_arr, 4, 0)
    border = utils.get_bg_borders(pred_arr)
    gt_arr = utils.cut_to_size(gt_arr, border)
    pred_arr = utils.cut_to_size(pred_arr, border)
    print(utils.compute_psnr(pred_arr, gt_arr))
    print(utils.compute_wpsnr(pred_arr, gt_arr))
Example 3
def main():
    ## data
    print('Loading data...')
    test_hr_path = os.path.join('data/', dataset)
    if dataset == 'Set5':
        ext = '*.bmp'
    else:
        ext = '*.png'
    hr_paths = sorted(glob.glob(os.path.join(test_hr_path, ext)))

    ## model
    print('Loading model...')
    tensor_lr = tf.placeholder('float32', [1, None, None, 3], name='tensor_lr')
    tensor_b = tf.placeholder('float32', [1, None, None, 3], name='tensor_b')

    tensor_sr = IDN(tensor_lr, tensor_b, scale)
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                            log_device_placement=False))
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()
    saver.restore(sess, model_path)

    ## result
    save_path = os.path.join(saved_path, dataset + '/x' + str(scale))
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    psnr_score = 0
    for i, _ in enumerate(hr_paths):
        print('processing image %d' % (i + 1))
        img_hr = utils.modcrop(misc.imread(hr_paths[i]), scale)
        img_lr = utils.downsample_fn(img_hr, scale=scale)
        img_b = utils.upsample_fn(img_lr, scale=scale)
        [lr, b] = utils.datatype([img_lr, img_b])
        lr = lr[np.newaxis, :, :, :]
        b = b[np.newaxis, :, :, :]
        [sr] = sess.run([tensor_sr], {tensor_lr: lr, tensor_b: b})
        sr = utils.quantize(np.squeeze(sr))
        img_sr = utils.shave(sr, scale)
        img_hr = utils.shave(img_hr, scale)
        if not rgb:
            img_pre = utils.quantize(sc.rgb2ycbcr(img_sr)[:, :, 0])
            img_label = utils.quantize(sc.rgb2ycbcr(img_hr)[:, :, 0])
        else:
            img_pre = img_sr
            img_label = img_hr
        psnr_score += utils.compute_psnr(img_pre, img_label)
        misc.imsave(os.path.join(save_path, os.path.basename(hr_paths[i])), sr)

    print('Average PSNR: %.4f' % (psnr_score / len(hr_paths)))
    print('Finish')
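
The SR evaluation above leans on a few small helpers (modcrop, shave, quantize) that most super-resolution codebases share. Hedged sketches of what they typically do (the exact implementations in this repository may differ):

import numpy as np

def modcrop(img, scale):
    # Sketch: crop height and width down to multiples of the scale factor.
    h = img.shape[0] - img.shape[0] % scale
    w = img.shape[1] - img.shape[1] % scale
    return img[:h, :w, ...]

def shave(img, border):
    # Sketch: discard a `border`-pixel frame before computing metrics.
    return img[border:-border, border:-border, ...]

def quantize(img):
    # Sketch: round and clip to the 8-bit range, as done before Y-channel PSNR.
    return np.clip(np.round(img), 0, 255).astype(np.uint8)
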
Example 4
def test(name_img, model, img, sr_factor, gt=False, img_gt=None):
    out_file = r'../output'
    if not os.path.exists(out_file):
        os.makedirs(out_file)
    model.eval()

    img_bicubic = img.resize(
        (int(img.size[0] * sr_factor), int(img.size[1] * sr_factor)),
        resample=PIL.Image.BICUBIC)
    img_bicubic.save(os.path.join(out_file, name_img + '_bicubic.png'))

    input = transforms.ToTensor()(img_bicubic)
    input = torch.unsqueeze(input, 0)
    input = input.to(device)
    with torch.no_grad():
        out = model(input)
    out = out.data.cpu()
    out = out.clamp(min=0, max=1)
    out = torch.squeeze(out, 0)
    out = transforms.ToPILImage()(out)
    out.save(os.path.join(out_file, name_img + '_zssr.png'))

    if gt:
        ssim_bicubic = compute_ssim(img_gt, img_bicubic)
        psnr_bicubic = compute_psnr(img_gt, img_bicubic)
        ssim_zssr = compute_ssim(img_gt, out)
        psnr_zssr = compute_psnr(img_gt, out)
        print("psnr_bicubic:\t{:.2f}".format(psnr_bicubic))
        print("ssim_bicubic:\t{:.4f}".format(ssim_bicubic))
        print("psnr_zssr:\t{:.2f}".format(psnr_zssr))
        print("ssim_zssr:\t{:.4f}".format(ssim_zssr))
        with open(os.path.join(out_file, 'PSNR_and_SSIM.txt'), mode='a') as fo:
            fo.write(str(name_img) + ':\n')
            fo.write(
                '\tbicubic: psnr:{:.2f}\tssim:{:.4f}\tzssr: psnr:{:.2f}\tssim:{:.4f}\n'
                .format(psnr_bicubic, ssim_bicubic, psnr_zssr, ssim_zssr))
        return ssim_bicubic, psnr_bicubic, ssim_zssr, psnr_zssr
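
compute_ssim and compute_psnr are called here directly on PIL images. A hedged sketch of how the SSIM side could be wrapped with scikit-image (assuming version >= 0.19 for channel_axis and 8-bit RGB inputs; the repo's own helper may differ):

import numpy as np
from skimage.metrics import structural_similarity

def compute_ssim_pil(img_gt, img_test):
    # Hypothetical sketch: convert PIL images to uint8 arrays and compare
    # across the color channels.
    gt = np.asarray(img_gt)
    test = np.asarray(img_test)
    return structural_similarity(gt, test, data_range=255, channel_axis=-1)

Example 5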
def infer(valid_queue, model):
  psnr = utils.AvgrageMeter()
  ssim = utils.AvgrageMeter()
  loss = utils.AvgrageMeter()
  model.eval()

  with torch.no_grad():
    for _, (input, target) in enumerate(valid_queue):
      input = input.cuda()
      target = target.cuda()
      logits = model(input)

      l = MSELoss(logits, target)
      s = pytorch_ssim.ssim(torch.clamp(logits,0,1), target)
      p = utils.compute_psnr(np.clip(logits.detach().cpu().numpy(),0,1), target.detach().cpu().numpy())
      n = input.size(0)
      psnr.update(p, n)
      ssim.update(s, n)
      loss.update(l, n)
  
  return psnr.avg, ssim.avg, loss.avg
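
utils.AvgrageMeter (spelling as in the source repositories) is a plain running-average accumulator; a minimal sketch consistent with the update(val, n)/avg usage above (the repos' own versions may differ slightly):

class AvgrageMeter:
    # Sketch of a running-average meter.
    def __init__(self):
        self.reset()

    def reset(self):
        self.avg = 0.0
        self.sum = 0.0
        self.cnt = 0

    def update(self, val, n=1):
        # Accumulate a value observed over n samples and refresh the mean.
        self.sum += val * n
        self.cnt += n
        self.avg = self.sum / self.cnt
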
Example 6
def test():
    avg_psnr = 0

    for batch in testing_data_loader:
        input, target = batch[0].detach(), batch[1].detach()
        model.feed_data([input], need_HR=False)
        model.test()
        pre = model.get_current_visuals(need_HR=False)
        sr_img = utils.tensor2np(pre['SR'].data)
        gt_img = utils.tensor2np(target.data[0])
        crop_size = args.scale
        cropped_sr_img = utils.shave(sr_img, crop_size)
        cropped_gt_img = utils.shave(gt_img, crop_size)
        if is_y is True:
            im_label = utils.quantize(sc.rgb2ycbcr(cropped_gt_img)[:, :, 0])
            im_pre = utils.quantize(sc.rgb2ycbcr(cropped_sr_img)[:, :, 0])
        else:
            im_label = cropped_gt_img
            im_pre = cropped_sr_img
        avg_psnr += utils.compute_psnr(im_pre, im_label)

    print("===> Valid. psnr: {:.4f}".format(avg_psnr /
                                            len(testing_data_loader)))
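
utils.tensor2np above turns a tensor in [0, 1] back into an HWC uint8 image before the Y-channel conversion; a hedged sketch of such a helper (the actual implementation may differ):

import numpy as np
import torch

def tensor2np_sketch(t):
    # Hypothetical sketch: CHW (or NCHW) float tensor in [0, 1] -> HWC uint8.
    arr = t.detach().cpu().clamp(0, 1).numpy()
    if arr.ndim == 4:   # NCHW -> take the first image
        arr = arr[0]
    if arr.ndim == 3:   # CHW -> HWC
        arr = arr.transpose(1, 2, 0)
    return (arr * 255.0).round().astype(np.uint8)
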
Example 7
                            FLAGS,
                            reuse=False)
    # net_val = SloMo_model(data_val.frame0, data_val.frame1, data_val.frameT, FLAGS, reuse=True)

    print('Finish building the network!!!')

    # Convert back to uint8
    frame0 = tf.image.convert_image_dtype(data_train.frame0, dtype=tf.uint8)
    frame1 = tf.image.convert_image_dtype(data_train.frame1, dtype=tf.uint8)
    frameT = tf.image.convert_image_dtype(data_train.frameT, dtype=tf.uint8)
    pred_frameT = tf.image.convert_image_dtype(net_train.pred_frameT,
                                               dtype=tf.uint8)

    # Compute PSNR
    with tf.name_scope("compute_psnr"):
        psnr = compute_psnr(frameT, pred_frameT)

    # Add Image summaries
    with tf.name_scope("input_summaries"):
        tf.summary.image("frame0", frame0)
        tf.summary.image("frame1", frame1)

    with tf.name_scope("target_summary"):
        tf.summary.image("frameT", frameT)

    with tf.name_scope("output_summaries"):
        tf.summary.image("predicted_frameT", pred_frameT)
        tf.summary.image("FTO", tf.expand_dims(net_train.Ft0[:, :, :, 0],
                                               -1))  # showing only one channel
        tf.summary.image("FT1", tf.expand_dims(net_train.Ft1[:, :, :, 0],
                                               -1))  # showing only one channel
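
For TensorFlow graphs like this one, PSNR over uint8 frames can be expressed with the built-in op; a hedged sketch of what compute_psnr might wrap here (hypothetical name, batch-averaged):

import tensorflow as tf

def compute_psnr_tf_sketch(frames_true, frames_pred):
    # Per-image PSNR with an 8-bit dynamic range, averaged over the batch.
    return tf.reduce_mean(tf.image.psnr(frames_true, frames_pred, max_val=255))
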
Example 8
def main():
    conf = get_config()
    extension_module = conf.nnabla_context.context
    ctx = get_extension_context(extension_module,
                                device_id=conf.nnabla_context.device_id)
    comm = CommunicatorWrapper(ctx)
    nn.set_default_context(comm.ctx)
    print("#GPU Count: ", comm.n_procs)

    data_iterator_train = jsi_iterator(conf.batch_size, conf, train=True)
    if conf.scaling_factor == 1:
        d_t = nn.Variable((conf.batch_size, 80, 80, 3), need_grad=True)
        l_t = nn.Variable((conf.batch_size, 80, 80, 3), need_grad=True)

    else:
        d_t = nn.Variable((conf.batch_size, 160 // conf.scaling_factor,
                           160 // conf.scaling_factor, 3),
                          need_grad=True)
        l_t = nn.Variable((conf.batch_size, 160, 160, 3), need_grad=True)

    if comm.n_procs > 1:
        data_iterator_train = data_iterator_train.slice(
            rng=None, num_of_slices=comm.n_procs, slice_pos=comm.rank)

    monitor_path = './nnmonitor' + \
        str(datetime.datetime.now().strftime("%Y%m%d%H%M%S"))

    monitor = Monitor(monitor_path)
    jsi_monitor = setup_monitor(conf, monitor)

    with nn.parameter_scope("jsinet"):
        nn.load_parameters(conf.pre_trained_model)
        net = model(d_t, conf.scaling_factor)
        net.pred.persistent = True
    rec_loss = F.mean(F.squared_error(net.pred, l_t))
    rec_loss.persistent = True
    g_final_loss = rec_loss

    if conf.jsigan:
        net_gan = gan_model(l_t, net.pred, conf)
        d_final_fm_loss = net_gan.d_adv_loss
        d_final_fm_loss.persistent = True
        d_final_detail_loss = net_gan.d_detail_adv_loss
        d_final_detail_loss.persistent = True
        g_final_loss = conf.rec_lambda * rec_loss + conf.adv_lambda * (
            net_gan.g_adv_loss + net_gan.g_detail_adv_loss
        ) + conf.fm_lambda * (net_gan.fm_loss + net_gan.fm_detail_loss)
        g_final_loss.persistent = True

    max_iter = data_iterator_train._size // (conf.batch_size)
    if comm.rank == 0:
        print("max_iter", data_iterator_train._size, max_iter)

    iteration = 0
    if not conf.jsigan:
        start_epoch = 0
        end_epoch = conf.adv_weight_point
        lr = conf.learning_rate * comm.n_procs
    else:
        start_epoch = conf.adv_weight_point
        end_epoch = conf.epoch
        lr = conf.learning_rate * comm.n_procs
        w_d = conf.weight_decay * comm.n_procs

    # Set generator parameters
    with nn.parameter_scope("jsinet"):
        solver_jsinet = S.Adam(alpha=lr, beta1=0.9, beta2=0.999, eps=1e-08)
        solver_jsinet.set_parameters(nn.get_parameters())

    if conf.jsigan:
        solver_disc_fm = S.Adam(alpha=lr, beta1=0.9, beta2=0.999, eps=1e-08)
        solver_disc_detail = S.Adam(alpha=lr,
                                    beta1=0.9,
                                    beta2=0.999,
                                    eps=1e-08)
        with nn.parameter_scope("Discriminator_FM"):
            solver_disc_fm.set_parameters(nn.get_parameters())
        with nn.parameter_scope("Discriminator_Detail"):
            solver_disc_detail.set_parameters(nn.get_parameters())

    for epoch in range(start_epoch, end_epoch):
        for index in range(max_iter):
            d_t.d, l_t.d = data_iterator_train.next()

            if not conf.jsigan:
                # JSI-net -> Generator
                lr_stair_decay_points = [200, 225]
                lr_net = get_learning_rate(lr, iteration,
                                           lr_stair_decay_points,
                                           conf.lr_decreasing_factor)
                g_final_loss.forward(clear_no_need_grad=True)
                solver_jsinet.zero_grad()
                if comm.n_procs > 1:
                    all_reduce_callback = comm.get_all_reduce_callback()
                    g_final_loss.backward(
                        clear_buffer=True,
                        communicator_callbacks=all_reduce_callback)
                else:
                    g_final_loss.backward(clear_buffer=True)
                solver_jsinet.set_learning_rate(lr_net)
                solver_jsinet.update()
            else:
                # GAN part (discriminator + generator)
                lr_gan = lr if epoch < conf.gan_lr_linear_decay_point \
                    else lr * (end_epoch - epoch) / (end_epoch - conf.gan_lr_linear_decay_point)
                lr_gan = lr_gan * conf.gan_ratio

                net.pred.need_grad = False

                # Discriminator_FM
                solver_disc_fm.zero_grad()
                d_final_fm_loss.forward(clear_no_need_grad=True)
                if comm.n_procs > 1:
                    all_reduce_callback = comm.get_all_reduce_callback()
                    d_final_fm_loss.backward(
                        clear_buffer=True,
                        communicator_callbacks=all_reduce_callback)
                else:
                    d_final_fm_loss.backward(clear_buffer=True)
                solver_disc_fm.set_learning_rate(lr_gan)
                solver_disc_fm.weight_decay(w_d)
                solver_disc_fm.update()

                # Discriminator_Detail
                solver_disc_detail.zero_grad()
                d_final_detail_loss.forward(clear_no_need_grad=True)
                if comm.n_procs > 1:
                    all_reduce_callback = comm.get_all_reduce_callback()
                    d_final_detail_loss.backward(
                        clear_buffer=True,
                        communicator_callbacks=all_reduce_callback)
                else:
                    d_final_detail_loss.backward(clear_buffer=True)
                solver_disc_detail.set_learning_rate(lr_gan)
                solver_disc_detail.weight_decay(w_d)
                solver_disc_detail.update()

                # Generator
                net.pred.need_grad = True
                solver_jsinet.zero_grad()
                g_final_loss.forward(clear_no_need_grad=True)
                if comm.n_procs > 1:
                    all_reduce_callback = comm.get_all_reduce_callback()
                    g_final_loss.backward(
                        clear_buffer=True,
                        communicator_callbacks=all_reduce_callback)
                else:
                    g_final_loss.backward(clear_buffer=True)
                solver_jsinet.set_learning_rate(lr_gan)
                solver_jsinet.update()

            iteration += 1
            if comm.rank == 0:
                train_psnr = compute_psnr(net.pred.d, l_t.d, 1.)
                jsi_monitor['psnr'].add(iteration, train_psnr)
                jsi_monitor['rec_loss'].add(iteration, rec_loss.d.copy())
                jsi_monitor['time'].add(iteration)

            if comm.rank == 0:
                if conf.jsigan:
                    jsi_monitor['g_final_loss'].add(iteration,
                                                    g_final_loss.d.copy())
                    jsi_monitor['g_adv_loss'].add(iteration,
                                                  net_gan.g_adv_loss.d.copy())
                    jsi_monitor['g_detail_adv_loss'].add(
                        iteration, net_gan.g_detail_adv_loss.d.copy())
                    jsi_monitor['d_final_fm_loss'].add(
                        iteration, d_final_fm_loss.d.copy())
                    jsi_monitor['d_final_detail_loss'].add(
                        iteration, d_final_detail_loss.d.copy())
                    jsi_monitor['fm_loss'].add(iteration,
                                               net_gan.fm_loss.d.copy())
                    jsi_monitor['fm_detail_loss'].add(
                        iteration, net_gan.fm_detail_loss.d.copy())
                    jsi_monitor['lr'].add(iteration, lr_gan)

        if comm.rank == 0:
            if not os.path.exists(conf.output_dir):
                os.makedirs(conf.output_dir)
            with nn.parameter_scope("jsinet"):
                nn.save_parameters(
                    os.path.join(conf.output_dir,
                                 "model_param_%04d.h5" % epoch))
Example 9
        out_img_s = out_s.detach().numpy().squeeze()
        out_img_s = utils.convert_shape(out_img_s)

        out_img_p = out_p.detach().numpy().squeeze()
        out_img_p = utils.convert_shape(out_img_p)

    if opt.isHR:
        if opt.only_y is True:
            im_label = utils.quantize(sc.rgb2ycbcr(im_gt)[:, :, 0])
            im_pre = utils.quantize(sc.rgb2ycbcr(out_img_c)[:, :, 0])
        else:
            im_label = im_gt
            im_pre = out_img_c

        psnr_sr[i] = utils.compute_psnr(
            utils.shave(im_label, opt.upscale_factor),
            utils.shave(im_pre, opt.upscale_factor))
        ssim_sr[i] = utils.compute_ssim(
            utils.shave(im_label, opt.upscale_factor),
            utils.shave(im_pre, opt.upscale_factor))
    i += 1

    output_c_folder = os.path.join(
        opt.output_folder,
        imname.split('/')[-1].split('.')[0] + '_c.png')
    output_s_folder = os.path.join(
        opt.output_folder,
        imname.split('/')[-1].split('.')[0] + '_s.png')
    output_p_folder = os.path.join(
        opt.output_folder,
        imname.split('/')[-1].split('.')[0] + '_p.png')
Example 10
def main(args):
    cfg = cfg_dict[args.cfg_name]
    writer = SummaryWriter(os.path.join("runs", args.cfg_name))
    train_loader = get_data_loader(cfg, cfg["train_dir"])

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = EDSR(cfg).to(device)
    criterion = torch.nn.L1Loss()
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg["init_lr"],
                                 betas=(0.9, 0.999), eps=1e-8)

    global_batches = 0
    if args.train:
        for epoch in range(cfg["n_epoch"]):
            model.train()
            running_loss = 0.0
            for i, batch in enumerate(train_loader):
                lr, hr = batch[0].to(device), batch[1].to(device)
                optimizer.zero_grad()
                sr = model(lr)
                loss = model.loss(sr, hr)
                # loss = criterion(model(lr), hr)
                running_loss += loss.item()
                loss.backward()
                optimizer.step()
                global_batches += 1
                if global_batches % cfg["lr_decay_every"] == 0:
                    for param_group in optimizer.param_groups:
                        print(f"decay lr to {param_group['lr'] / 10}")
                        param_group["lr"] /= 10

            if epoch % args.log_every == 0:
                model.eval()
                with torch.no_grad():
                    batch_samples = {"lr": batch[0], "hr": batch[1], 
                                     "sr": sr.cpu()}
                    writer.add_scalar("training-loss", 
                                      running_loss / len(train_loader),
                                      global_step=global_batches)
                    writer.add_scalar("PSNR", compute_psnr(batch_samples), 
                                      global_step=global_batches)
                    samples = {k: v[:3] for k, v in batch_samples.items()}
                    fig = visualize_samples(samples, f"epoch-{epoch}")
                    writer.add_figure("sample-visualization", fig, 
                                      global_step=global_batches)

            if epoch % args.save_every == 0:
                state = {"net": model.state_dict(), 
                         "optim": optimizer.state_dict()}
                checkpoint_dir = args.checkpoint_dir
                if not os.path.exists(checkpoint_dir):
                    os.makedirs(checkpoint_dir)
                path = os.path.join(checkpoint_dir, args.cfg_name)
                torch.save(state, path)
    
    # eval
    if args.eval:
        assert args.model_path and args.lr_img_path
        print(f"evaluating {args.lr_img_path}")
        state = torch.load(args.model_path, map_location=device)
        model.load_state_dict(state["net"])
        optimizer.load_state_dict(state["optim"])

        with torch.no_grad():
            lr = img2tensor(args.lr_img_path)
            sr = model(lr.clone().to(device)).cpu()
            samples = {"lr": lr, "sr": sr}
            if args.hr_img_path:
                samples["hr"] = img2tensor(args.hr_img_path)
                print(f"PSNR: {compute_psnr(samples)}")
            directory = os.path.dirname(args.lr_img_path)
            name = f"eval-{args.cfg_name}-{args.lr_img_path.split('/')[-1]}"
            visualize_samples(samples, name, save=True, 
                              directory=directory, size=6)
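
img2tensor above loads an image file into a batched CHW float tensor in [0, 1] for the model; a hedged sketch with PIL (the actual helper in this repository may differ):

import numpy as np
import torch
from PIL import Image

def img2tensor_sketch(path):
    # Hypothetical sketch: read an RGB image, return a 1xCxHxW tensor in [0, 1].
    arr = np.asarray(Image.open(path).convert('RGB'), dtype=np.float32) / 255.0
    return torch.from_numpy(arr.transpose(2, 0, 1)).unsqueeze(0)
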
Example 11
def main():
    args = get_args()

    writer = SummaryWriter(args.work_dir)
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    log_file = osp.join(args.work_dir, '{}.log'.format(timestamp))
    logger = get_root_logger(log_file)

    train_dataset = Dataset(dataset=args.train_dataset,
                            split='train',
                            crop_cfg=dict(type='random',
                                          patch_size=args.patch_size),
                            flip_and_rotate=True)
    val_dataset = Dataset(dataset=args.valid_dataset,
                          split='valid',
                          override_length=args.num_valids,
                          crop_cfg=None)
    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

    model = EDSR(num_blocks=args.num_blocks, channels=args.num_channels)
    loss_fn = tf.keras.losses.MeanAbsoluteError()
    optimizer = tf.keras.optimizers.Adam(args.learning_rate)

    best_psnr = 0

    for epoch in range(1, args.num_epochs + 1):
        losses = []
        for lr, hr in tqdm(train_loader):
            lr = tf.constant(lr, dtype=tf.float32)
            hr = tf.constant(hr, dtype=tf.float32)
            with tf.GradientTape() as tape:
                sr = model(lr)
                loss = loss_fn(hr, sr)
            gradients = tape.gradient(loss, model.trainable_variables)
            optimizer.apply_gradients(zip(gradients,
                                          model.trainable_variables))

            losses.append(loss.numpy())
        logger.info(f'Epoch {epoch} - loss: {np.mean(losses)}')
        writer.add_scalar('loss', np.mean(losses), epoch)

        # eval
        if epoch % args.eval_freq == 0 or epoch == args.num_epochs:
            logger.info('Evaluating...')
            psnrs = []
            for i, (lr, hr) in enumerate(val_loader):
                lr = tf.constant(lr, dtype=tf.float32)
                hr = tf.constant(hr, dtype=tf.float32)
                sr = model(lr)
                cur_psnr = compute_psnr(sr, hr)
                psnrs.append(cur_psnr)
                update_tfboard(writer, i, lr, hr, sr, epoch)
            psnr = np.mean(psnrs)
            if psnr > best_psnr:
                best_psnr = psnr
            model.save_weights(osp.join(args.work_dir, f'epoch_{epoch}'))
            logger.info('psnr: {:.2f} (best={:.2f})'.format(psnr, best_psnr))
            writer.add_scalar('psnr', psnr, epoch)
            writer.flush()
Example 12
def inference():
    """
    Inference function to generate high resolution hdr images
    """
    conf = get_config()
    ctx = get_extension_context(conf.nnabla_context.context,
                                device_id=conf.nnabla_context.device_id)
    nn.set_default_context(ctx)

    data, target = read_mat_file(conf.data.lr_sdr_test,
                                 conf.data.hr_hdr_test,
                                 conf.data.d_name_test,
                                 conf.data.l_name_test,
                                 train=False)

    if not os.path.exists(conf.test_img_dir):
        os.makedirs(conf.test_img_dir)

    data_sz = data.shape
    target_sz = target.shape
    PATCH_BOUNDARY = 10  # set patch boundary to reduce edge effect around patch edges
    test_loss_PSNR_list_for_epoch = []
    inf_time = []
    start_time = time.time()

    test_pred_full = np.zeros((target_sz[1], target_sz[2], target_sz[3]))

    print("Loading pre trained model.........", conf.pre_trained_model)
    nn.load_parameters(conf.pre_trained_model)

    for index in range(data_sz[0]):
        ###======== Divide Into Patches ========###
        for p in range(conf.test_patch**2):
            pH = p // conf.test_patch
            pW = p % conf.test_patch
            sH = data_sz[1] // conf.test_patch
            sW = data_sz[2] // conf.test_patch
            H_low_ind, H_high_ind, W_low_ind, W_high_ind = \
                get_hw_boundary(
                    PATCH_BOUNDARY, data_sz[1], data_sz[2], pH, sH, pW, sW)
            data_test_p = nn.Variable.from_numpy_array(
                data.d[index, H_low_ind:H_high_ind, W_low_ind:W_high_ind, :])
            data_test_sz = data_test_p.shape
            data_test_p = F.reshape(
                data_test_p,
                (1, data_test_sz[0], data_test_sz[1], data_test_sz[2]))
            st = time.time()
            net = model(data_test_p, conf.scaling_factor)
            net.pred.forward()
            test_pred_temp = net.pred.d
            inf_time.append(time.time() - st)
            test_pred_t = trim_patch_boundary(test_pred_temp, PATCH_BOUNDARY,
                                              data_sz[1], data_sz[2], pH, sH,
                                              pW, sW, conf.scaling_factor)
            #pred_sz = test_pred_t.shape
            test_pred_t = np.squeeze(test_pred_t)
            test_pred_full[pH * sH * conf.scaling_factor:(pH + 1) * sH *
                           conf.scaling_factor,
                           pW * sW * conf.scaling_factor:(pW + 1) * sW *
                           conf.scaling_factor, :] = test_pred_t

        ###======== Compute PSNR & Print Results========###
        test_GT = np.squeeze(target.d[index, :, :, :])
        test_PSNR = compute_psnr(test_pred_full, test_GT, 1.)
        test_loss_PSNR_list_for_epoch.append(test_PSNR)
        print(
            " <Test> [%4d/%4d]-th images, time: %4.4f(minutes), test_PSNR: %.8f[dB]  "
            % (int(index), int(data_sz[0]),
               (time.time() - start_time) / 60, test_PSNR))
        if conf.save_images:
            # comment for faster testing
            save_results_yuv(test_pred_full, index, conf.test_img_dir)
    test_PSNR_per_epoch = np.mean(test_loss_PSNR_list_for_epoch)

    print("######### Average Test PSNR: %.8f[dB]  #########" %
          (test_PSNR_per_epoch))
    print(
        "######### Estimated Inference Time (per 4K frame): %.8f[s]  #########"
        % (np.mean(inf_time) * conf.test_patch * conf.test_patch))
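
get_hw_boundary / trim_patch_boundary implement overlapped patch-wise inference: each patch is grown by PATCH_BOUNDARY pixels (clamped at the image border) before the network runs, and the extra margin is trimmed off afterwards. A hedged sketch of the boundary computation only (the trim step additionally needs the scale factor, as in the call above):

def get_hw_boundary_sketch(boundary, h, w, p_h, s_h, p_w, s_w):
    # Hypothetical sketch: extend the (p_h, p_w)-th patch of size (s_h, s_w)
    # by `boundary` pixels on every side, clamped to the image extent.
    h_low = max(p_h * s_h - boundary, 0)
    h_high = min((p_h + 1) * s_h + boundary, h)
    w_low = max(p_w * s_w - boundary, 0)
    w_high = min((p_w + 1) * s_w + boundary, w)
    return h_low, h_high, w_low, w_high
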
Example 13
def train(args):

    print('Number of GPUs available: ' + str(torch.cuda.device_count()))
    model = nn.DataParallel(CAEP(num_resblocks).cuda())
    print('Done Setup Model.')

    dataset = BSDS500Crop128(args.dataset_path)
    dataloader = DataLoader(dataset,
                            batch_size=args.batch_size,
                            shuffle=args.shuffle,
                            num_workers=args.num_workers)
    testset = Kodak(args.testset_path)
    testloader = DataLoader(testset,
                            batch_size=testset.__len__(),
                            num_workers=args.num_workers)
    print(
        f"Done Setup Training DataLoader: {len(dataloader)} batches of size {args.batch_size}"
    )
    print(f"Done Setup Testing DataLoader: {len(testset)} Images")

    MSE = nn.MSELoss()
    SSIM = pytorch_msssim.SSIM().cuda()
    MSSSIM = pytorch_msssim.MSSSIM().cuda()

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.learning_rate,
                                 weight_decay=1e-10)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=0.5,
        patience=10,
        verbose=True,
    )

    writer = SummaryWriter(log_dir=f'TBXLog/{args.exp_name}')

    # ADMM variables
    Z = torch.zeros(16, 32, 32).cuda()
    U = torch.zeros(16, 32, 32).cuda()
    Z.requires_grad = False
    U.requires_grad = False

    if args.load != '':
        pretrained_state_dict = torch.load(f"./chkpt/{args.load}/model.state")
        current_state_dict = model.state_dict()
        current_state_dict.update(pretrained_state_dict)
        model.load_state_dict(current_state_dict)
        # Z = torch.load(f"./chkpt/{args.load}/Z.state")
        # U = torch.load(f"./chkpt/{args.load}/U.state")
        if args.load == args.exp_name:
            optimizer.load_state_dict(
                torch.load(f"./chkpt/{args.load}/opt.state"))
            scheduler.load_state_dict(
                torch.load(f"./chkpt/{args.load}/lr.state"))
        print('Model Params Loaded.')

    model.train()

    for ei in range(args.res_epoch + 1, args.res_epoch + args.num_epochs + 1):
        # train
        train_loss = 0
        train_ssim = 0
        train_msssim = 0
        train_psnr = 0
        train_peanalty = 0
        train_bpp = 0
        avg_c = torch.zeros(16, 32, 32).cuda()
        avg_c.requires_grad = False

        for bi, crop in enumerate(dataloader):
            x = crop.cuda()
            y, c = model(x)

            psnr = compute_psnr(x, y)
            mse = MSE(y, x)
            ssim = SSIM(x, y)
            msssim = MSSSIM(x, y)

            mix = 1000 * (1 - msssim) + 1000 * (1 - ssim) + 1e4 * mse + (45 -
                                                                         psnr)
            # ADMM Step 1
            peanalty = rho / 2 * torch.norm(c - Z + U, 2)
            bpp = compute_bpp(c, x.shape[0], 'crop', save=False)

            avg_c += torch.mean(c.detach() /
                                (len(dataloader) * args.admm_every),
                                dim=0)

            loss = mix + peanalty
            if ei == 1 and args.load != args.exp_name:
                loss = 1e5 * mse  # warm up

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            print(
                '[%3d/%3d][%5d/%5d] Loss: %f, SSIM: %f, MSSSIM: %f, PSNR: %f, Norm of Code: %f, BPP: %2f'
                % (ei, args.num_epochs + args.res_epoch, bi, len(dataloader),
                   loss, ssim, msssim, psnr, peanalty, bpp))
            writer.add_scalar('batch_train/loss', loss,
                              ei * len(dataloader) + bi)
            writer.add_scalar('batch_train/ssim', ssim,
                              ei * len(dataloader) + bi)
            writer.add_scalar('batch_train/msssim', msssim,
                              ei * len(dataloader) + bi)
            writer.add_scalar('batch_train/psnr', psnr,
                              ei * len(dataloader) + bi)
            writer.add_scalar('batch_train/norm', peanalty,
                              ei * len(dataloader) + bi)
            writer.add_scalar('batch_train/bpp', bpp,
                              ei * len(dataloader) + bi)

            train_loss += loss.item() / len(dataloader)
            train_ssim += ssim.item() / len(dataloader)
            train_msssim += msssim.item() / len(dataloader)
            train_psnr += psnr.item() / len(dataloader)
            train_peanalty += peanalty.item() / len(dataloader)
            train_bpp += bpp / len(dataloader)

        writer.add_scalar('epoch_train/loss', train_loss, ei)
        writer.add_scalar('epoch_train/ssim', train_ssim, ei)
        writer.add_scalar('epoch_train/msssim', train_msssim, ei)
        writer.add_scalar('epoch_train/psnr', train_psnr, ei)
        writer.add_scalar('epoch_train/norm', train_peanalty, ei)
        writer.add_scalar('epoch_train/bpp', train_bpp, ei)

        if ei % args.admm_every == args.admm_every - 1:
            # ADMM Step 2
            Z = (avg_c + U).masked_fill_((torch.Tensor(
                np.argsort((avg_c + U).data.cpu().numpy(), axis=None)) >= int(
                    (1 - pruning_ratio) * 16 * 32 * 32)).view(16, 32,
                                                              32).cuda(),
                                         value=0)
            # ADMM Step 3
            U += avg_c - Z

        # test
        model.eval()
        val_loss = 0
        val_ssim = 0
        val_msssim = 0
        val_psnr = 0
        val_peanalty = 0
        val_bpp = 0
        for bi, (img, patches, _) in enumerate(testloader):
            avg_loss = 0
            avg_ssim = 0
            avg_msssim = 0
            avg_psnr = 0
            avg_peanalty = 0
            avg_bpp = 0
            for i in range(6):
                for j in range(4):
                    x = torch.Tensor(patches[:, i, j, :, :, :]).cuda()
                    y, c = model(x)

                    psnr = compute_psnr(x, y)
                    mse = MSE(y, x)
                    ssim = SSIM(x, y)
                    msssim = MSSSIM(x, y)

                    mix = 1000 * (1 - msssim) + 1000 * (
                        1 - ssim) + 1e4 * mse + (45 - psnr)

                    peanalty = rho / 2 * torch.norm(c - Z + U, 2)
                    bpp = compute_bpp(c,
                                      x.shape[0],
                                      f'Kodak_patches_{i}_{j}',
                                      save=True)
                    loss = mix + peanalty

                    avg_loss += loss.item() / 24
                    avg_ssim += ssim.item() / 24
                    avg_msssim += msssim.item() / 24
                    avg_psnr += psnr.item() / 24
                    avg_peanalty += peanalty.item() / 24
                    avg_bpp += bpp / 24

            save_kodak_img(model, img, 0, patches, writer, ei)
            save_kodak_img(model, img, 10, patches, writer, ei)
            save_kodak_img(model, img, 20, patches, writer, ei)

            val_loss += avg_loss
            val_ssim += avg_ssim
            val_msssim += avg_msssim
            val_psnr += avg_psnr
            val_peanalty += avg_peanalty
            val_bpp += avg_bpp
        print(
            '*Kodak: [%3d/%3d] Loss: %f, SSIM: %f, MSSSIM: %f, Norm of Code: %f, BPP: %.2f'
            % (ei, args.num_epochs + args.res_epoch, val_loss, val_ssim,
               val_msssim, val_peanalty, val_bpp))

        # bz = call('tar -jcvf ./code/code.tar.bz ./code', shell=True)
        # total_code_size = os.stat('./code/code.tar.bz').st_size
        # total_bpp = total_code_size * 8 / 24 / 768 / 512

        writer.add_scalar('test/loss', val_loss, ei)
        writer.add_scalar('test/ssim', val_ssim, ei)
        writer.add_scalar('test/msssim', val_msssim, ei)
        writer.add_scalar('test/psnr', val_psnr, ei)
        writer.add_scalar('test/norm', val_peanalty, ei)
        writer.add_scalar('test/bpp', val_bpp, ei)
        # writer.add_scalar('test/total_bpp', total_bpp, ei)
        model.train()

        scheduler.step(train_loss)

        # save model
        if ei % args.save_every == args.save_every - 1:
            torch.save(model.state_dict(),
                       f"./chkpt/{args.exp_name}/model.state")
            torch.save(optimizer.state_dict(),
                       f"./chkpt/{args.exp_name}/opt.state")
            torch.save(scheduler.state_dict(),
                       f"./chkpt/{args.exp_name}/lr.state")
            torch.save(Z, f"./chkpt/{args.exp_name}/Z.state")
            torch.save(U, f"./chkpt/{args.exp_name}/U.state")

    writer.close()
Example 14
    def test_png(self):
        # saver to save model
        self.saver = tf.train.Saver()
        tf.global_variables_initializer().run()
        # restore check-point
        self.load(self.checkpoint_dir)  # for testing JSI-GAN
        # self.load_pretrained_model(self.checkpoint_dir, 'JSInet')  # for testing JSInet

        """" Test """
        data_path_test = glob.glob(os.path.join(self.test_data_path_LR_SDR, '*.png'))
        label_path_test = glob.glob(os.path.join(self.test_data_path_HR_HDR, '*.png'))

        """ Make "test_img_dir" per experiment """
        test_img_dir = os.path.join(self.test_img_dir, self.model_dir)
        if not os.path.exists(test_img_dir):
            os.makedirs(test_img_dir)
        """ Testing """
        patch_boundary = 10  # set patch boundary to reduce edge effect around patch edges
        test_loss_PSNR_list_for_epoch = []
        inf_time = []
        start_time = time.time()
        for index in range(len(data_path_test)//3):
            ###======== Read Data ========###
            y = np.array(Image.open(data_path_test[3*index+2]))
            u = np.array(Image.open(data_path_test[3*index]))
            v = np.array(Image.open(data_path_test[3*index+1]))
            ###======== Pre-process Data ========###
            img = np.expand_dims(np.stack([y, u, v], axis=2), axis=0)
            data_sz = img.shape
            test_pred_full = np.zeros((data_sz[1]*self.scale_factor, data_sz[2]*self.scale_factor, data_sz[3]))
            img = np.array(img, dtype=np.double) / 255.
            data_test = np.clip(img, 0, 1)
            ###======== Divide Into Patches ========###
            for p in range(self.test_patch[0] * self.test_patch[1]):
                pH = p // self.test_patch[1]
                pW = p % self.test_patch[1]
                sH = data_sz[1] // self.test_patch[0]
                sW = data_sz[2] // self.test_patch[1]
                # process data considering patch boundary
                H_low_ind, H_high_ind, W_low_ind, W_high_ind = \
                    get_HW_boundary(patch_boundary, data_sz[1], data_sz[2], pH, sH, pW, sW)
                data_test_p = data_test[:, H_low_ind: H_high_ind, W_low_ind: W_high_ind, :]
                ###======== Run Session ========###
                st = time.time()
                test_pred_o = self.sess.run(self.test_pred, feed_dict={self.test_input_ph: data_test_p})
                inf_time.append(time.time() - st)
                # trim patch boundary
                test_pred_t = trim_patch_boundary(test_pred_o, patch_boundary, data_sz[1], data_sz[2], pH, sH, pW, sW, self.scale_factor)
                # store in pred_full
                test_pred_full[pH * sH * self.scale_factor: (pH + 1) * sH * self.scale_factor,
                pW * sW * self.scale_factor: (pW + 1) * sW * self.scale_factor, :] = np.squeeze(test_pred_t)
            ###======== Compute PSNR & Print Results========###
            label_y = np.array(Image.open(label_path_test[3*index+2]))
            label_u = np.array(Image.open(label_path_test[3*index]))
            label_v = np.array(Image.open(label_path_test[3*index+1]))
            test_GT = np.stack([label_y, label_u, label_v], axis=2)
            test_GT = np.array(test_GT, dtype=np.double) / 1023.
            test_GT = np.clip(test_GT, 0, 1)
            test_PSNR = utils.compute_psnr(test_pred_full, test_GT, 1.)
            test_loss_PSNR_list_for_epoch.append(test_PSNR)
            print(" <Test> [%4d/%4d]-th images, time: %4.4f(minutes), test_PSNR: %.8f[dB]  "
                  % (int(index), int(len(data_path_test)//3), (time.time() - start_time) / 60, test_PSNR))

            ###======== Save Predictions as Images ========###
            utils.save_results_yuv(test_pred_full, index, test_img_dir)  # comment for faster testing
        test_PSNR_per_epoch = np.mean(test_loss_PSNR_list_for_epoch)

        print("######### Average Test PSNR: %.8f[dB]  #########" % (test_PSNR_per_epoch))
        print("######### Estimated Inference Time (per 4K frame): %.8f[s]  #########" % (np.mean(inf_time)*self.test_patch[0]*self.test_patch[1]))
Example 15
    psnr_sum = 0.

    for gt_vid in training_generator:
        gt_vid = gt_vid.cuda()
        if not args.two_bucket:
            # b1 = c2b(gt_vid) # (N,1,H,W)
            b1 = torch.mean(gt_vid, dim=1, keepdim=True)
            # interm_vid = utils.impulse_inverse(b1, block_size=args.blocksize)
            # assert interm_vid.shape == gt_vid.shape
            highres_vid = uNet(b1)  # (N,16,H,W)
        else:
            b1, b0 = c2b(gt_vid)
            b_stack = torch.cat([b1, b0], dim=1)
            highres_vid = uNet(b_stack)

        psnr_sum += utils.compute_psnr(highres_vid, gt_vid).item()

        ## LOSSES
        final_loss = utils.weighted_L1loss(highres_vid, gt_vid)
        final_loss_sum += final_loss.item()

        tv_loss = utils.gradx(highres_vid).abs().mean() + utils.grady(
            highres_vid).abs().mean()
        tv_loss_sum += tv_loss.item()

        loss = final_loss + 0.1 * tv_loss
        loss_sum += loss.item()

        ## BACKPROP
        optimizer.zero_grad()
        loss.backward()
Example 16
def test_abs(n_test, n_batch, n_steps, alpha, u, x_test):

    x_test = np.expand_dims(x_test, axis=3)
    _, height, width, nc = x_test.shape

    device_id = 0
    torch.cuda.set_device(device_id)

    zeropad = nn.ZeroPad2d(height // 2)

    x_test = x_test[:n_test, :, :, :].reshape(-1, nc, height, width)

    N_iter = int(np.ceil(n_test / float(n_batch)))
    x_test_rec = np.zeros_like(x_test)

    eps_tensor = torch.cuda.FloatTensor([1e-15])
    epoch_idx = np.arange(n_test)

    pbar = tqdm(range(N_iter))
    for iters in pbar:

        x = x_test[epoch_idx[iters *
                             n_batch:np.min([(iters + 1) *
                                             n_batch, n_test])], :, :, :]
        x_gt = torch.cuda.FloatTensor(x).view(-1, nc, height, width).cuda()
        uk = torch.cuda.FloatTensor(u).view(-1, nc, height, width)

        # z = x + u
        z = zeropad(x_gt + uk)
        dummy_zeros = torch.zeros_like(z).cuda()
        z_complex = torch.cat((z.unsqueeze(4), dummy_zeros.unsqueeze(4)), 4)

        Fz = torch.fft(z_complex, 2, normalized=True)
        # y = |F(x+u)| = |Fz|
        y = torch.norm(Fz, dim=4)
        y_dual = torch.cat((y.unsqueeze(4), y.unsqueeze(4)), 4)

        x_est = x_test_rec[
            epoch_idx[iters * n_batch:np.min([(iters + 1) *
                                              n_batch, n_test])], :, :, :]
        x_est = torch.cuda.FloatTensor(x_est).cuda()

        # image loss and measurement loss
        loss_x_pr = []
        loss_y_pr = []
        for kx in range(n_steps):

            z_est = zeropad(x_est + uk + eps_tensor)
            z_est_complex = torch.cat(
                (z_est.unsqueeze(4), dummy_zeros.unsqueeze(4)), 4)
            Fz_est = torch.fft(z_est_complex, 2, normalized=True)
            y_est = torch.norm(Fz_est, dim=4)
            y_est_dual = torch.cat((y_est.unsqueeze(4), y_est.unsqueeze(4)), 4)
            # angle Fz
            Fz_est_phase = Fz_est / (y_est_dual + eps_tensor)
            # update x
            x_grad_complex = torch.ifft(Fz_est -
                                        torch.mul(Fz_est_phase, y_dual),
                                        2,
                                        normalized=True)
            x_grad = x_grad_complex[:, :, height // 2:height // 2 + height,
                                    width // 2:width // 2 + width, 0]
            x_est = x_est - alpha * x_grad
            x_est = torch.clamp(x_est, 0, 1)

            loss_x_pr.append(np.mean((x - x_est.cpu().detach().numpy())**2))
            loss_y_pr.append(height * 2 * width * 2 * np.mean(
                (y.cpu().detach().numpy().reshape(-1, 2 * height, 2 * width) -
                 np.abs(
                     np.fft.fft2(z_est.cpu().detach().numpy().reshape(
                         -1, 2 * height, 2 * width),
                                 norm="ortho")))**2))

        x_test_rec[epoch_idx[iters * n_batch:np.min([(
            iters +
            1) * n_batch, n_test])], :, :, :] = x_est.cpu().detach().numpy()

    mse_list = [
        compare_mse(x_test[i, 0, :, :], x_test_rec[i, 0, :, :])
        for i in range(n_test)
    ]
    psnr_list = [
        compute_psnr(x_test[i, 0, :, :], x_test_rec[i, 0, :, :])
        for i in range(n_test)
    ]
    ssim_list = [
        compare_ssim(x_test[i, 0, :, :], x_test_rec[i, 0, :, :])
        for i in range(n_test)
    ]
    print(f'mse {np.mean(mse_list):.2f}')
    print(f'psnr {np.mean(psnr_list):.2f}')
    print(f'ssim {np.mean(ssim_list):.2f}')

    mse = np.mean((x_test_rec - x_test)**2)
    psnr = 20 * np.log10((np.max(x_test) - np.min(x_test)) / np.sqrt(mse))
    print(f'mean mse {mse:.2f}')
    print(f'psnr of mean {psnr:.2f}')
    print(f'psnr of mean (mean of psnr) {psnr:.2f}({np.mean(psnr_list):.2f})')

    return x_test_rec, mse_list, psnr_list, ssim_list
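
compare_mse, compare_ssim, and compute_psnr here follow the old skimage.measure naming; if they do come from scikit-image (an assumption), the modern equivalents can be aliased as:

from skimage.metrics import mean_squared_error as compare_mse
from skimage.metrics import peak_signal_noise_ratio as compute_psnr
from skimage.metrics import structural_similarity as compare_ssim
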
Example 17
    loss_sum = 0.
    psnr_sum = 0.
    for gt_vid in training_generator:

        gt_vid = gt_vid.cuda()
        if not args.two_bucket:
            b1 = c2b(gt_vid)  # (N,1,H,W)
            # b1 = torch.mean(gt_vid, dim=1, keepdim=True)
            interm_vid = invNet(b1)
        else:
            b1, b0 = c2b(gt_vid)
            b_stack = torch.cat([b1, b0], dim=1)
            interm_vid = invNet(b_stack)
        highres_vid = uNet(interm_vid)  # (N,16,H,W)

        psnr_sum += utils.compute_psnr(highres_vid, gt_vid).item()

        ## LOSSES
        if args.intermediate:
            interm_loss = utils.weighted_L1loss(interm_vid, gt_vid)
            interm_loss_sum += interm_loss.item()

        final_loss = utils.weighted_L1loss(highres_vid, gt_vid)
        final_loss_sum += final_loss.item()

        tv_loss = utils.gradx(highres_vid).abs().mean() + utils.grady(
            highres_vid).abs().mean()
        tv_loss_sum += tv_loss.item()

        if args.intermediate:
            loss = final_loss + 0.1 * tv_loss + 0.5 * interm_loss
Example 18
def infer(data_path, model):
    psnr = utils.AvgrageMeter()
    ssim = utils.AvgrageMeter()
    times = utils.AvgrageMeter()

    model.eval()
    rgb2gray = torchvision.transforms.Grayscale(1)
    transforms = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor()
    ])

    with torch.no_grad():
        for step, pt in enumerate(glob.glob(data_path)):
            image = utils.crop_img(np.array(rgb2gray(Image.open(pt)))[..., np.newaxis])
            noise_map = np.random.randn(*(image.shape))*args.sigma
            noise_img = np.clip(image+noise_map,0,255).astype(np.uint8)

            # # Test on whole image
            # input = transforms(noise_img).unsqueeze(dim=0).cuda()
            # target = transforms(image).unsqueeze(dim=0).cuda()
            # begin_time = time.time()
            # logits = model(input)
            # end_time = time.time()
            # n = input.size(0)

            # Test on whole image with data augmentation
            target = transforms(image).unsqueeze(dim=0).cuda()
            for i in range(8):
                im = utils.data_augmentation(noise_img,i)
                input = transforms(im.copy()).unsqueeze(dim=0).cuda()
                begin_time = time.time()
                if i == 0:
                    logits = utils.inverse_augmentation(model(input).cpu().numpy().transpose(0,2,3,1)[0],i)
                else:
                    logits = logits + utils.inverse_augmentation(model(input).cpu().numpy().transpose(0,2,3,1)[0],i)
                end_time = time.time()
            n = input.size(0)
            logits = transforms(logits/8).unsqueeze(dim=0).cuda()

            # # Test on patches2patches
            # noise_patches = utils.slice_image2patches(noise_img, patch_size=64)
            # image_patches = utils.slice_image2patches(image, patch_size=64)
            # input = torch.tensor(noise_patches.transpose(0,3,1,2)/255.0, dtype=torch.float32).cuda()
            # target = torch.tensor(image_patches.transpose(0,3,1,2)/255.0, dtype=torch.float32).cuda()
            # begin_time = time.time()
            # logits = model(input)
            # end_time = time.time()
            # n = input.size(0)

            s = pytorch_ssim.ssim(torch.clamp(logits,0,1), target)
            p = utils.compute_psnr(np.clip(logits.detach().cpu().numpy(),0,1), target.detach().cpu().numpy())
            t = end_time-begin_time
            psnr.update(p, n)
            ssim.update(s, n)
            times.update(t,n)
            print('psnr:%6f ssim:%6f time:%6f' % (p, s, t))
            
            # Image.fromarray(noise_img[...,0]).save(args.save+'/'+str(step)+'_noise.png')
            # Image.fromarray(np.clip(logits[0,0].cpu().numpy()*255, 0, 255).astype(np.uint8)).save(args.save+'/'+str(step)+'_denoised.png')

    return psnr.avg, ssim.avg, times.avg
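
utils.data_augmentation / utils.inverse_augmentation (used here and in the first snippet) implement the usual 8-fold flip/rotation test-time augmentation and its inverse. A hedged sketch of one consistent mode mapping (the repositories' own ordering of the 8 modes may differ):

import numpy as np

def data_augmentation_sketch(img, mode):
    # mode 0-3: rotate by 0/90/180/270 degrees; mode 4-7: same plus a flip.
    out = np.rot90(img, k=mode % 4)
    if mode >= 4:
        out = np.flipud(out)
    return out

def inverse_augmentation_sketch(img, mode):
    # Undo the flip first, then rotate back.
    out = np.flipud(img) if mode >= 4 else img
    return np.rot90(out, k=-(mode % 4))
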
Example 19
    
    net = SRNDeblurNet().cuda()
    set_requires_grad(net, False)
    last_epoch = load_model(net, args.resume, epoch=args.resume_epoch)

    log_dir = '{}/test/{}'.format(args.resume, args.dataset)
    os.system('mkdir -p {}'.format(log_dir))
    psnr_list = []

    tt = time()
    with torch.no_grad():
        for step, batch in tqdm(enumerate(dataloader), total=len(dataloader)):
            for k in batch:
                # `async` is a reserved keyword on Python 3.7+; the kwarg is
                # now called non_blocking.
                batch[k] = batch[k].cuda(non_blocking=True)
                batch[k].requires_grad = False

            y256, y128, y64 = net(batch['img256'], batch['img128'],
                                  batch['img64'])
            if step == 0:
                print(y256.shape)
            psnr_list.append(
                compute_psnr(y256, batch['label256'], 2).cpu().numpy())
            if step % 100 == 99:
                t = time()
                psnr = np.mean(psnr_list)
                tqdm.write("{} / {} : psnr {} , {} img/s".format(
                    step, len(dataloader) - 1, psnr,
                    100 * args.batch_size / (t - tt)))
                tt = t
    psnr = np.mean(psnr_list)
    print(psnr)

    with open('{}/psnr.txt'.format(log_dir), 'a') as log_fp:
        log_fp.write('epoch {} : psnr {}\n'.format(last_epoch, psnr))
Example 20
            torch.cuda.synchronize()
            time_list[i] = start.elapsed_time(end)  # milliseconds
        else:
            start.record()
            out = crop_forward(im_input, model)
            end.record()
            torch.cuda.synchronize()
            time_list[i] = start.elapsed_time(end)  # milliseconds

    sr_img = utils.tensor2np(out.detach()[0])
    if opt.is_y is True:
        im_label = utils.quantize(sc.rgb2ycbcr(im_gt)[:, :, 0])
        im_pre = utils.quantize(sc.rgb2ycbcr(sr_img)[:, :, 0])
    else:
        im_label = im_gt
        im_pre = sr_img
    psnr_list[i] = utils.compute_psnr(im_pre, im_label)
    ssim_list[i] = utils.compute_ssim(im_pre, im_label)

    output_folder = os.path.join(opt.output_folder, imname.split('/')[-1])

    if not os.path.exists(opt.output_folder):
        os.makedirs(opt.output_folder)

    sio.imsave(output_folder, sr_img)
    i += 1

print("Mean PSNR: {}, SSIM: {}, Time: {} ms".format(np.mean(psnr_list),
                                                    np.mean(ssim_list),
                                                    np.mean(time_list)))
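
The timing in this truncated snippet relies on CUDA events; a minimal sketch of the setup it implies (measuring a forward pass in milliseconds, assuming a CUDA device is available):

import torch

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

start.record()
# out = model(im_input)  # the measured forward pass (model/input assumed)
end.record()
torch.cuda.synchronize()  # wait for the GPU before reading the timer
elapsed_ms = start.elapsed_time(end)
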
Example 21
    def test_mat(self):
        # saver to save model
        self.saver = tf.train.Saver()
        tf.global_variables_initializer().run()
        # restore check-point
        self.load(self.checkpoint_dir)  # for testing JSI-GAN
        # self.load_pretrained_model(self.checkpoint_dir, 'JSInet')  # for testing JSInet

        """" Test """
        """ Matlab data for test """
        data_path_test = self.test_data_path_LR_SDR
        label_path_test = self.test_data_path_HR_HDR
        data_test, label_test = read_mat_file(data_path_test, label_path_test, 'SDR_YUV', 'HDR_YUV')
        data_sz = data_test.shape
        label_sz = label_test.shape

        """ Make "test_img_dir" per experiment """
        test_img_dir = os.path.join(self.test_img_dir, self.model_dir)
        if not os.path.exists(test_img_dir):
            os.makedirs(test_img_dir)

        """ Testing """
        patch_boundary = 10  # set patch boundary to reduce edge effect around patch edges
        test_loss_PSNR_list_for_epoch = []
        inf_time = []
        start_time = time.time()
        test_pred_full = np.zeros((label_sz[1], label_sz[2], label_sz[3]))
        for index in range(data_sz[0]):
            ###======== Divide Into Patches ========###
            for p in range(self.test_patch[0] * self.test_patch[1]):
                pH = p // self.test_patch[1]
                pW = p % self.test_patch[1]
                sH = data_sz[1] // self.test_patch[0]
                sW = data_sz[2] // self.test_patch[1]
                # process data considering patch boundary
                H_low_ind, H_high_ind, W_low_ind, W_high_ind = \
                    get_HW_boundary(patch_boundary, data_sz[1], data_sz[2], pH, sH, pW, sW)
                data_test_p = data_test[index, H_low_ind: H_high_ind, W_low_ind: W_high_ind, :]
                data_test_p = np.expand_dims(data_test_p, axis=0)
                ###======== Run Session ========###
                st = time.time()
                test_pred_o = self.sess.run(self.test_pred, feed_dict={self.test_input_ph: data_test_p})
                inf_time.append(time.time() - st)
                # trim patch boundary
                test_pred_t = trim_patch_boundary(test_pred_o, patch_boundary, data_sz[1], data_sz[2], pH, sH, pW, sW, self.scale_factor)
                # store in pred_full
                test_pred_full[pH * sH * self.scale_factor: (pH + 1) * sH * self.scale_factor,
                pW * sW * self.scale_factor: (pW + 1) * sW * self.scale_factor, :] = np.squeeze(test_pred_t)
            ###======== Compute PSNR & Print Results========###
            test_GT = np.squeeze(label_test[index, :, :, :])
            test_PSNR = utils.compute_psnr(test_pred_full, test_GT, 1.)
            test_loss_PSNR_list_for_epoch.append(test_PSNR)
            print(" <Test> [%4d/%4d]-th images, time: %4.4f(minutes), test_PSNR: %.8f[dB]  "
                  % (int(index), int(data_sz[0]), (time.time() - start_time) / 60, test_PSNR))

            ###======== Save Predictions as Images ========###
            utils.save_results_yuv(test_pred_full, index, test_img_dir)  # comment for faster testing
        test_PSNR_per_epoch = np.mean(test_loss_PSNR_list_for_epoch)

        print("######### Average Test PSNR: %.8f[dB]  #########" % (test_PSNR_per_epoch))
        print("######### Estimated Inference Time (per 4K frame): %.8f[s]  #########" % (np.mean(inf_time)*self.test_patch[0]*self.test_patch[1]))