Example #1
0
    def test(self):
        """Evaluate the model on ``self.loader_test`` and log PSNR.

        For each sample, the model output ``sr = self.model(lr)`` is scored
        against ``hr`` with ``utils.calc_psnr`` and reported to ``self.ckp``.
        The PSNR of ``lre`` against ``hr`` is accumulated separately and
        printed at the end as "Bicubic PSNR". Presumably ``lre`` is a bicubic
        upscale of ``lr`` (TODO confirm).

        After the loop, the best epoch is read from ``self.ckp.psnr_log``.
        Images are saved when ``args.save_images`` is set. Unless
        ``args.test_only`` is set, a checkpoint is saved.
        """
        epoch = self.scheduler.last_epoch + 1
        self.ckp.write_log('\nEvaluation:')
        self.model.eval()
        self.ckp.start_log(train=False)
        with torch.no_grad():
            tqdm_test = tqdm(self.loader_test, ncols=80)
            # Running sum of PSNR(lre, hr); averaged and printed after the loop.
            bic_PSNR = 0
            for idx_img, (lr, lre, hr, filename) in enumerate(tqdm_test):
                ycbcr_flag = False
                if self.args.n_colors == 1 and lr.size()[1] == 3:
                    # If n_colors is 1, split image into Y,Cb,Cr
                    # Channel 0 goes through the model and the metrics.
                    # Channels 1: are kept on the device and re-attached after
                    # postprocess. The SR output borrows lre's chroma.
                    ycbcr_flag = True
                    sr_cbcr = lre[:, 1:, :, :].to(self.device)
                    lre = lre[:, 0:1, :, :]
                    lr_cbcr = lr[:, 1:, :, :].to(self.device)
                    lr = lr[:, 0:1, :, :]
                    hr_cbcr = hr[:, 1:, :, :].to(self.device)
                    hr = hr[:, 0:1, :, :]

                # Presumably batch size 1: only the first file name is used.
                filename = filename[0]
                lre = lre.to(self.device)
                lr = lr.to(self.device)
                hr = hr.to(self.device)
                sr = self.model(lr)
                PSNR = utils.calc_psnr(self.args, sr, hr)
                # The baseline PSNR is accumulated locally, not reported to ckp.
                bic_PSNR += utils.calc_psnr(self.args, lre, hr)
                self.ckp.report_log(PSNR, train=False)
                lr, hr, sr = utils.postprocess(lr,
                                               hr,
                                               sr,
                                               rgb_range=self.args.rgb_range,
                                               ycbcr_flag=ycbcr_flag,
                                               device=self.device)

                if ycbcr_flag:
                    lr = torch.cat((lr, lr_cbcr), dim=1)
                    hr = torch.cat((hr, hr_cbcr), dim=1)
                    sr = torch.cat((sr, sr_cbcr), dim=1)

                save_list = [lr, hr, sr]
                if self.args.save_images:
                    self.ckp.save_images(filename, save_list, self.args.scale)

            self.ckp.end_log(len(self.loader_test), train=False)
            # psnr_log.max(0) is indexed as (best values, their indices).
            # The index + 1 is treated as the epoch number below.
            best = self.ckp.psnr_log.max(0)
            self.ckp.write_log(
                '[{}]\taverage PSNR: {:.3f} (Best: {:.3f} @epoch {})'.format(
                    self.args.data_test, self.ckp.psnr_log[-1], best[0],
                    best[1] + 1))
            # NOTE(review): this baseline is only printed; it is not passed to
            # self.ckp.write_log like the line above. An empty test loader
            # makes the division raise ZeroDivisionError.
            print('Bicubic PSNR: {:.3f}'.format(bic_PSNR /
                                                len(self.loader_test)))
            if not self.args.test_only:
                self.ckp.save(self, epoch, is_best=(best[1] + 1 == epoch))
Example #2
0
    def test(self):
        """Evaluate frame-to-frame motion compensation on ``self.loader_test``.

        Each sample's ``lr`` holds at least two frames.
        ``self.model(frame1, frame2)`` returns a compensated ``frame2`` plus a
        ``flow`` field. The PSNR between ``frame1`` and the compensated frame
        is reported to ``self.ckp``.

        When ``args.save_images`` is set, every 10th sample is saved.
        Afterwards the best epoch is read from ``self.ckp.psnr_log``. Unless
        ``args.test_only`` is set, a checkpoint is saved.
        """
        epoch = self.scheduler.last_epoch + 1
        self.ckp.write_log('\nEvaluation:')
        self.model.eval()
        self.ckp.start_log(train=False)
        with torch.no_grad():
            tqdm_test = tqdm(self.loader_test, ncols=80)
            for idx_img, (lr, _, filename) in enumerate(tqdm_test):
                ycbcr_flag = False
                filename = filename[0][0]
                lr = lr.to(self.device)
                # Presumably lr is (N, n_frames, C, H, W) -- TODO confirm.
                # The first two entries along dim 1 are the pair to align.
                frame1, frame2 = lr[:, 0], lr[:, 1]
                if self.args.n_colors == 1 and lr.size()[-3] == 3:
                    # Single-channel model on 3-channel data: feed only
                    # channel 0 and keep channels 1: to re-attach later.
                    ycbcr_flag = True
                    frame1_cbcr = frame1[:, 1:]
                    frame2_cbcr = frame2[:, 1:]
                    frame1 = frame1[:, 0:1]
                    frame2 = frame2[:, 0:1]

                frame2_compensated, flow = self.model(frame1, frame2)

                # Metric: how closely the warped frame2 matches frame1.
                PSNR = utils.calc_psnr(self.args, frame1, frame2_compensated)
                self.ckp.report_log(PSNR, train=False)
                frame1, frame2, frame2c = utils.postprocess(
                    frame1,
                    frame2,
                    frame2_compensated,
                    rgb_range=self.args.rgb_range,
                    ycbcr_flag=ycbcr_flag,
                    device=self.device)

                if ycbcr_flag:
                    frame1 = torch.cat((frame1, frame1_cbcr), dim=1)
                    frame2 = torch.cat((frame2, frame2_cbcr), dim=1)
                    # The set-aside channels of frame2 are warped with the same
                    # flow, so the saved compensated frame has all channels.
                    # Presumably flow is (N, 2, H, W) in grid_sample's
                    # normalised [-1, 1] coordinates -- TODO confirm.
                    frame2_cbcr_c = F.grid_sample(frame2_cbcr,
                                                  flow.permute(0, 2, 3, 1),
                                                  padding_mode='border')
                    frame2c = torch.cat((frame2c, frame2_cbcr_c), dim=1)

                save_list = [frame1, frame2, frame2c]
                if self.args.save_images and idx_img % 10 == 0:
                    self.ckp.save_images(filename, save_list, self.args.scale)

            self.ckp.end_log(len(self.loader_test), train=False)
            # psnr_log.max(0) is indexed as (best values, their indices).
            # The index + 1 is treated as the epoch number below.
            best = self.ckp.psnr_log.max(0)
            self.ckp.write_log(
                '[{}]\taverage PSNR: {:.3f} (Best: {:.3f} @epoch {})'.format(
                    self.args.data_test, self.ckp.psnr_log[-1], best[0],
                    best[1] + 1))
            if not self.args.test_only:
                self.ckp.save(self, epoch, is_best=(best[1] + 1 == epoch))
Example #3
0
    def validation_step(self, batch, batch_idx):
        """Score the student model on one validation batch.

        Returns a dict containing:
        - the input batch (``"origs"``),
        - the PSNR of the prediction against the target (``'psnr'``),
        - the target (``"gt"``),
        - the raw student prediction (``"predict"``).

        ``batch_idx`` is unused.
        """
        source, target = batch[0], batch[1]
        prediction, _ = self.student_model(source)

        # Both images get the same channel reversal before being scored.
        pred_img = torch2image(prediction)[:, :, ::-1]
        target_img = torch2image(target)[:, :, ::-1]

        return {
            "origs": source,
            'psnr': calc_psnr(pred_img, target_img),
            "gt": target,
            "predict": prediction,
        }
Example #4
0
    def test_step(self, batch, batch_idx):
        """Score the student on one test batch and dump comparison images.

        Runs both the student and the teacher on ``batch[0]``. Computes PSNR
        and SSIM of the student prediction against ``batch[1]``. Writes the
        input, teacher output, student output and ground truth as PNG files
        under ``./results/<experiment key>/``, then logs the metrics.

        Args:
            batch: indexable pair; ``batch[0]`` is the model input and
                ``batch[1]`` the ground truth.
            batch_idx: batch index, used as the file-name prefix.
        """
        inputs, hr_gt = batch[0], batch[1]
        hr_pred, _ = self.student_model(inputs)
        hr_pred_teacher, _ = self.teacher_model(inputs)

        # Every image gets the same channel reversal before use.
        lr_numpy = torch2image(inputs)[:, :, ::-1]
        hr_teacher_numpy = torch2image(hr_pred_teacher)[:, :, ::-1]
        hr_pred_numpy = torch2image(hr_pred)[:, :, ::-1]
        hr_gt_numpy = torch2image(hr_gt)[:, :, ::-1]

        psnr = calc_psnr(hr_gt_numpy, hr_pred_numpy)
        ssim = calc_ssim(hr_gt_numpy, hr_pred_numpy)

        # Name the output folder after the experiment key when the logger can
        # supply one, and fall back to "tmp" otherwise. Catching Exception
        # (not a bare except) keeps this best-effort without also swallowing
        # KeyboardInterrupt/SystemExit.
        try:
            name = self.logger.experiment[0].get_key()
        except Exception:
            name = "tmp"

        # exist_ok=True avoids the exists()/makedirs() race and also creates
        # ./results/ itself.
        out_dir = os.path.join("./results", name)
        os.makedirs(out_dir, exist_ok=True)

        images = (
            ("pred", hr_pred_numpy),
            ("teacher", hr_teacher_numpy),
            ("gt", hr_gt_numpy),
            ("lr", lr_numpy),
        )
        for suffix, array in images:
            Image.fromarray(array).save(
                os.path.join(out_dir, "{}_{}.png".format(batch_idx, suffix)),
                format="PNG")

        self.log_dict({"psnr": psnr, "ssim": ssim})
Example #5
0
    # Resize to (image_width, image_height), which are computed outside this
    # fragment. Then build the args.scale-times downscaled LR input and its
    # bicubic re-upscale.
    hr = image.resize((image_width, image_height), resample=pil_image.BICUBIC)
    lr = hr.resize((hr.width // args.scale, hr.height // args.scale),
                   resample=pil_image.BICUBIC)
    bicubic = lr.resize((lr.width * args.scale, lr.height * args.scale),
                        resample=pil_image.BICUBIC)
    # NOTE(review): str.replace('.', ...) rewrites every '.' in the path
    # (e.g. './photo.png'), not only the extension dot. os.path.splitext
    # would target the extension alone.
    bicubic.save(
        args.image_file.replace('.', '_bicubic_x{}.'.format(args.scale)))

    # HWC array -> 1 x C x H x W float32 array divided by 255, then to tensor.
    lr = np.expand_dims(
        np.array(lr).astype(np.float32).transpose([2, 0, 1]), 0) / 255.0
    hr = np.expand_dims(
        np.array(hr).astype(np.float32).transpose([2, 0, 1]), 0) / 255.0
    lr = torch.from_numpy(lr).to(device)
    hr = torch.from_numpy(hr).to(device)

    with torch.no_grad():
        preds = model(lr).squeeze(0)

    # PSNR is measured on the output of convert_rgb_to_y (presumably luma)
    # of the denormalized images, not on the full RGB image.
    preds_y = convert_rgb_to_y(denormalize(preds), dim_order='chw')
    hr_y = convert_rgb_to_y(denormalize(hr.squeeze(0)), dim_order='chw')

    # Ignore a border of args.scale pixels on every side.
    preds_y = preds_y[args.scale:-args.scale, args.scale:-args.scale]
    hr_y = hr_y[args.scale:-args.scale, args.scale:-args.scale]

    psnr = calc_psnr(hr_y, preds_y)
    print('PSNR: {:.2f}'.format(psnr))

    # Save the full-colour prediction as an HWC byte image.
    output = pil_image.fromarray(
        denormalize(preds).permute(1, 2, 0).byte().cpu().numpy())
    output.save(args.image_file.replace('.', '_rdn_x{}.'.format(args.scale)))
Example #6
0
        # Evaluate on three separate eval loaders, each with its own meter.
        # NOTE(review): the three loops are copies of one another and could
        # share a helper that takes the loader and the meter.
        model.eval()
        epoch_psnr = AverageMeter()
        epoch_psnr1 = AverageMeter()
        epoch_psnr2 = AverageMeter()

        # PSNR of predictions clamped to [0, 1]. Each update passes
        # len(inputs) as the count.
        for data in eval_dataloader:
            inputs, labels = data

            inputs = inputs.to(device)
            labels = labels.to(device)

            with torch.no_grad():
                preds = model(inputs).clamp(0.0, 1.0)  

            epoch_psnr.update(calc_psnr(preds, labels), len(inputs))

        # Same procedure for the second loader.
        for data in eval_dataloader1:
            inputs1, labels1 = data

            inputs1 = inputs1.to(device)
            labels1 = labels1.to(device)

            with torch.no_grad():
                preds1 = model(inputs1).clamp(0.0, 1.0)  

            epoch_psnr1.update(calc_psnr(preds1, labels1), len(inputs1))

        # Same procedure for the third loader (its body is cut off in this view).
        for data in eval_dataloader2:
            inputs2, labels2 = data
Example #7
0
def test(testloader, net, single=False, crop=0, ignore=0):
    """Evaluate ``net`` on ``testloader`` and return the mean PSNR per output.

    ``net`` produces ``CONFIG.DMUnetL`` outputs per input. Each output is
    scored against the size-normalised, un-cropped input image, and the
    scores are averaged over all batches.

    Args:
        testloader: iterable of ``(images, labels)``; ``labels`` is unused.
        net: model called on the RGGB-packed input; its result is indexed
            by output level ``0 .. CONFIG.DMUnetL - 1``.
        single: when True, also:
            - write per-picture CPSNR rows (and their average) to an .xlsx
              workbook,
            - save every output as ``L<level>/<pic>.bmp``,
            - print per-picture PSNR.
        crop: when > 0, wrap the forward pass in ``Crop``/``unCrop``.
        ignore: forwarded to ``calc_psnr``.

    Returns:
        numpy array of length ``CONFIG.DMUnetL`` with the mean PSNR of each
        output level.

    Raises:
        ValueError: if ``testloader`` yields no batches (previously this
            surfaced as a NameError on the loop variable).
    """
    L = CONFIG.DMUnetL
    oup_path = os.path.join(CONFIG.OUTPIC_FILE, CONFIG.Test_DataSet)
    # makedirs(exist_ok=True) also creates missing parents and avoids the
    # exists()/mkdir() race.
    os.makedirs(oup_path, exist_ok=True)

    psnr = [0] * L
    pic_psnr = [0] * L
    CPSNR = np.zeros((L, 4))
    if single:
        from openpyxl import Workbook
        wkbook = Workbook()
        bksheet = [wkbook.create_sheet('L' + str(j + 1), j) for j in range(L)]

    n_pics = 0
    with torch.no_grad():
        for i, data in enumerate(testloader, 0):
            n_pics = i + 1
            images, labels = data
            images = imgSIZEnormalize(images)
            # PSNR is always measured against the un-cropped image.
            ground_truth = images
            # torch.no_grad() needs torch >= 0.4. There, Variable(t) is a
            # no-op returning t, so the tensor is passed directly.
            if crop > 0:
                images = Crop(images)
                inputs = rgb2RGGB(images).to(device)
                outputs = unCrop(net(inputs))
            else:
                inputs = rgb2RGGB(images).to(device)
                outputs = net(inputs)

            ground_truth = ground_truth.to(device)
            for j in range(L):
                output = outputs[j]
                pic_psnr[j] = calc_psnr(output, ground_truth, ignore=ignore)
                psnr[j] += pic_psnr[j]

                if single:
                    temp = calc_psnr(output,
                                     ground_truth,
                                     ignore=ignore,
                                     cpsnr=True)
                    bksheet[j].append(['PIC {}'.format(i + 1)] + list(temp))
                    CPSNR[j, :] += temp
                    output = (output.clamp(0.0, 1.0) * 255).cpu().squeeze(0)
                    output = np.array(output,
                                      dtype=np.uint8).transpose([1, 2, 0])
                    output = pil_image.fromarray(output)
                    level_dir = os.path.join(oup_path, 'L{}'.format(j + 1))
                    os.makedirs(level_dir, exist_ok=True)
                    output.save(os.path.join(level_dir,
                                             '{}.bmp'.format(i + 1)))
            if single:
                print('No.{} pic... psnr = {} dB.'.format(
                    i + 1, np.array(pic_psnr)))

        if n_pics == 0:
            raise ValueError('testloader yielded no batches')

        psnr = [p / n_pics for p in psnr]
        if single:
            CPSNR /= n_pics
            for j in range(L):
                bksheet[j].append(['Ave.'] + list(CPSNR[j, :]))
            wkbook_path = os.path.join(
                oup_path, CONFIG.Test_DataSet +
                '_cpsnr_crop_{}_ignore_{}.xlsx'.format(crop, ignore))
            if os.path.exists(wkbook_path):
                os.remove(wkbook_path)
            wkbook.save(wkbook_path)
            print(CPSNR)
            print(psnr)
            print('All {} pics'.format(n_pics))
        return np.array(psnr)
Example #8
0
    #hr = image.resize((image_width, image_height), resample=pil_image.BICUBIC)
    # Crop (instead of resizing) to the top-left image_width x image_height
    # region; both values are computed outside this fragment.
    hr = image.crop((0, 0, image_width, image_height))
    lr = hr.resize((hr.width // args.scale, hr.height // args.scale),
                   resample=pil_image.BICUBIC)
    bicubic = lr.resize((lr.width * args.scale, lr.height * args.scale),
                        resample=pil_image.BICUBIC)
    # NOTE(review): str.replace('.', ...) rewrites every '.' in the path,
    # not only the extension dot.
    bicubic.save(
        args.image_file.replace('.', '_bicubic_x{}.'.format(args.scale)))

    # preprocess returns a pair. Its second element (ycbcr) is kept only for
    # the bicubic image.
    lr, _ = preprocess(lr, device)
    hr, _ = preprocess(hr, device)
    _, ycbcr = preprocess(bicubic, device)
    bicubic, _ = preprocess(bicubic, device)

    # Baseline: PSNR of the bicubic upscale against HR (no model involved).
    psnr = calc_psnr(hr, bicubic)
    print('PSNR: {:.2f}'.format(psnr))

    # Visualize the original lr and bicubic images
    print(lr.size())
    print(bicubic.size())

    # Show the first-layer weights from state_dict (defined outside this
    # fragment) and the three input images.
    kernel = state_dict['first_part.0.weight']
    show_images('first kernel', kernel)
    show_images("lr image", lr)
    show_images("hr image", hr)
    show_images("bicubic", bicubic)

    model.eval()

    print(state_dict.keys())
            # NOTE(review): these lines do not continue the snippet above
            # (indentation jumps from 4 to 12 columns). They appear to be the
            # body of an evaluation loop from a separate snippet whose header
            # is missing.
            inputs, hr_labels, hq_labels = data

            inputs = inputs.to(device=device, dtype=torch.float32)
            hr_labels = hr_labels.to(device=device, dtype=torch.float32)
            hq_labels = hq_labels.to(device=device, dtype=torch.float32)

            with torch.no_grad():
                hr_preds, hq_preds = model(inputs)

            hr_eval_loss = criterion(hr_preds, hr_labels)
            hq_eval_loss = criterion(hq_preds, hq_labels)

            # The reported eval loss is the sum of the HR- and HQ-head losses.
            eval_losses.update(hr_eval_loss.item() + hq_eval_loss.item(),
                               len(inputs))

            hr_epoch_psnr.update(calc_psnr(hr_preds, hr_labels), len(inputs))
            hq_epoch_psnr.update(calc_psnr(hq_preds, hq_labels), len(inputs))

        writer.add_scalar('eval_loss', eval_losses.avg, epoch)
        print('HR eval psnr: {:.2f}'.format(hr_epoch_psnr.avg))
        print('HQ eval psnr :{:.2f}'.format(hq_epoch_psnr.avg))
        writer.add_scalar('hr_psnr_eval', hr_epoch_psnr.avg, epoch)
        writer.add_scalar('hq_psnr_eval', hq_epoch_psnr.avg, epoch)

        # NOTE(review): hr_preds/hq_preds here are presumably those of the
        # last batch of the loop above, so the grids show that batch only.
        hr_pred_grid = torchvision.utils.make_grid(hr_preds)
        hq_pred_grid = torchvision.utils.make_grid(hq_preds)
        writer.add_image('HR prediction epoch : ' + str(epoch), hr_pred_grid)
        writer.add_image('HQ prediction epoch : ' + str(epoch), hq_pred_grid)
        # NOTE(review): the writer is closed here. Confirm it is not expected
        # to stay open for later epochs.
        writer.close()

        # best epoch choice is dependent on what output is to be optimized
Example #10
0
    def test(self):
        """Evaluate the multi-frame SR model on ``self.loader_test``.

        Each sample provides an LR frame sequence ``lr`` and ``hr``. The
        model predicts a single HR frame, whose PSNR against ``hr`` is
        reported to ``self.ckp``.

        When ``args.save_images`` is set, every 30th sample is saved.
        Afterwards the best epoch is read from ``self.ckp.psnr_log``. Unless
        ``args.test_only`` is set, a checkpoint is saved.

        Raises:
            ValueError: if the wrapped model is neither 'ESPCN_mf' nor
                'VESPCN'. Previously this surfaced as an UnboundLocalError
                on ``sr``.
        """
        epoch = self.scheduler.last_epoch + 1
        self.ckp.write_log('\nEvaluation:')
        self.model.eval()
        model_name = self.model.get_model().name
        if model_name not in ('ESPCN_mf', 'VESPCN'):
            raise ValueError(
                'Unsupported model for multi-frame test: {}'.format(model_name))
        self.ckp.start_log(train=False)
        with torch.no_grad():
            tqdm_test = tqdm(self.loader_test, ncols=80)
            for idx_img, (lr, hr, filename) in enumerate(tqdm_test):
                ycbcr_flag = False
                filename = filename[0][0]
                # lr: [batch_size, n_seq, 3, patch_size, patch_size]
                if self.args.n_colors == 1 and lr.size()[2] == 3:
                    # If n_colors is 1, split image into Y,Cb,Cr
                    ycbcr_flag = True
                    # Index of the middle frame along the sequence axis.
                    center = hr.shape[1] // 2
                    # for CbCr, select the middle frame
                    lr_center_y = lr[:, center, 0:1, :, :].to(self.device)
                    lr_cbcr = lr[:, center, 1:, :, :].to(self.device)
                    hr_cbcr = hr[:, center, 1:, :, :].to(self.device)
                    # extract Y channels (lr should be group, hr should be the center frame)
                    lr = lr[:, :, 0:1, :, :]
                    hr = hr[:, center, 0:1, :, :]

                # Divide LR frame sequence [N, n_sequence, n_colors, H, W] -> n_sequence * [N, 1, n_colors, H, W]
                # NOTE(review): the chunk size passed to split is n_colors,
                # not 1, so this yields one chunk per frame only when
                # n_colors == 1. Kept as-is; confirm against the training loop.
                lr = list(torch.split(lr, self.args.n_colors, dim=1))
                lr = [x.to(self.device) for x in lr]
                hr = hr.to(self.device)

                # output frame = single HR frame [N, n_colors, H, W]
                if model_name == 'ESPCN_mf':
                    sr = self.model(lr)
                else:  # 'VESPCN' (validated above)
                    sr, _, _ = self.model(lr)

                PSNR = utils.calc_psnr(self.args, sr, hr)
                self.ckp.report_log(PSNR, train=False)
                hr, sr = utils.postprocess(hr,
                                           sr,
                                           rgb_range=self.args.rgb_range,
                                           ycbcr_flag=ycbcr_flag,
                                           device=self.device)

                if self.args.save_images and idx_img % 30 == 0:
                    if ycbcr_flag:
                        [lr_center_y
                         ] = utils.postprocess(lr_center_y,
                                               rgb_range=self.args.rgb_range,
                                               ycbcr_flag=ycbcr_flag,
                                               device=self.device)
                        lr = torch.cat((lr_center_y, lr_cbcr), dim=1)
                        hr = torch.cat((hr, hr_cbcr), dim=1)
                        # The SR output has only the Y channel. It borrows the
                        # HR centre frame's chroma for the saved image.
                        sr = torch.cat((sr, hr_cbcr), dim=1)

                    # NOTE(review): without ycbcr_flag, lr is still the list of
                    # per-chunk tensors here. Confirm save_images accepts it.
                    save_list = [lr, hr, sr]

                    self.ckp.save_images(filename, save_list, self.args.scale)

            self.ckp.end_log(len(self.loader_test), train=False)
            # psnr_log.max(0) is indexed as (best values, their indices).
            best = self.ckp.psnr_log.max(0)
            self.ckp.write_log(
                '[{}]\taverage PSNR: {:.3f} (Best: {:.3f} @epoch {})'.format(
                    self.args.data_test, self.ckp.psnr_log[-1], best[0],
                    best[1] + 1))
            if not self.args.test_only:
                self.ckp.save(self, epoch, is_best=(best[1] + 1 == epoch))
Example #11
0
def PSNR(img1_path, img2_path, required_width, required_height):
    """Return ``calc_psnr`` of the ``get_torch_y`` data of two image files.

    Both files are loaded with the same required width and height. The first
    path is loaded first.
    """
    size = (required_width, required_height)
    return calc_psnr(get_torch_y(img1_path, *size),
                     get_torch_y(img2_path, *size))
Example #12
0
    # Load every SR and HR image as a single-channel array
    # (utils.imread_uint with n_channels=1).
    sr_img = []
    hr_img = []

    for img in utils.get_image_paths(sr_path):
        img = utils.imread_uint(img, n_channels=1)
        sr_img.append(img)
    
    for img in utils.get_image_paths(hr_path):
        img = utils.imread_uint(img, n_channels=1)
        hr_img.append(img)

    # NOTE(review): a count mismatch is only printed. The loop below pairs
    # images by position, so extra HR images are silently ignored and extra
    # SR images raise IndexError. Pairing relies on get_image_paths returning
    # both folders in matching order.
    if len(sr_img) != len(hr_img):
        print('ERROR: The number is not equal!')

    mean_rmse = 0
    mean_psnr = 0
    mean_ssim = 0
    for i in range(0, len(sr_img)):
        rmse, _ = utils.calc_rmse(sr_img[i], hr_img[i])
        psnr = utils.calc_psnr(sr_img[i], hr_img[i])
        ssim = utils.calc_ssim(sr_img[i], hr_img[i])

        logger.info('Image:{:03d} || RMSE:{} || PSNR:{} || SSIM:{}'.format(i+1, rmse, psnr, ssim))
        mean_rmse += rmse
        mean_psnr += psnr
        mean_ssim += ssim
    # NOTE(review): an empty SR folder makes these divisions raise
    # ZeroDivisionError.
    mean_rmse =  mean_rmse / len(sr_img)
    mean_psnr =  mean_psnr / len(sr_img)
    mean_ssim =  mean_ssim / len(sr_img)
    logger.info('AVG RMSE: {} || PSNR:{} || SSIM:{}'.format(mean_rmse, mean_psnr, mean_ssim))
    def do(self, phase, epoch, SR_model, loss, SR_optimizer, tr_dataloader,
           vl_dataloader, te_dataloader):
        """Run one phase ('train', 'valid' or 'test') of the SR pipeline.

        Args:
            phase: 'train', 'valid' or 'test'. Anything else prints
                'phase error!' and does nothing.
            epoch: current epoch, used in the logs and by the training
                exploding-loss guard.
            SR_model: dict of networks. 'net_G' is the generator; entries
                'net_G' and 'net_D' are switched to train mode for training.
            loss: object exposing ``SR_loss(sr, hr)``.
            SR_optimizer: dict of optimizers. ``SR_optimizer['net_G']``
                supplies the logged learning rate.
            tr_dataloader, vl_dataloader, te_dataloader: loaders for the
                corresponding phase. Each yields ``(lr, hr, name)`` triples.
        """
        if phase == 'train':
            self._do_train(epoch, SR_model, loss, SR_optimizer, tr_dataloader)
        elif phase == 'valid':
            self._do_valid(epoch, SR_model, vl_dataloader)
        elif phase == 'test':
            self._do_test(SR_model, te_dataloader)
        else:
            print('phase error!')

    def _do_train(self, epoch, SR_model, loss, SR_optimizer, tr_dataloader):
        """Train the generator for one pass over ``tr_dataloader``."""
        # set model to training mode!
        for model_type in list(SR_model.keys()):
            if model_type in ('net_G', 'net_D'):
                SR_model[model_type].train()

        loss_sum = 0.0
        valid_iter_cnt = 0
        for it, (lr, hr, _) in enumerate(tr_dataloader):
            lr, hr = utils.tensor_prepare([lr, hr], self.args)

            # forward/backward pass
            utils.opt_zerograd(SR_optimizer)
            sr = SR_model['net_G'](lr)
            loss_val = loss.SR_loss(sr, hr)
            # Only a Python float of the loss survives this iteration.
            # Summing or storing the tensor itself would keep its autograd
            # history alive for the whole epoch.
            self.loss_val = float(loss_val)
            self.lr_G_val = SR_optimizer['net_G'].param_groups[0]["lr"]
            loss_val.backward()

            # skip parameter update when loss is exploded
            if (epoch != 0 and it != 0) and (self.loss_val >
                                             self.loss_val_prev * 10):
                print('loss_val: %f\tloss_val_prev: %f\tskip this batch!' %
                      (self.loss_val, self.loss_val_prev))
                continue

            # update parameters
            utils.sch_opt_step(SR_optimizer)

            # save current loss to utilize next iteration
            self.loss_val_prev = self.loss_val
            valid_iter_cnt += 1
            loss_sum += self.loss_val

            if it % self.args.print_freq == 0:
                tr_res_txt = 'epoch: %d\tlr: %f\t%s loss: %05.2f\titer: %d/%d\t[%s]\n' % \
                             (epoch, self.lr_G_val, self.args.loss, loss_sum/valid_iter_cnt,
                              it*self.args.batch_size, len(tr_dataloader.dataset),
                              datetime.now())

                with open(self.f_tr_fname, 'at') as self.f_tr_rec:
                    self.f_tr_rec.write(tr_res_txt)
                print(tr_res_txt[:-1])

    def _do_valid(self, epoch, SR_model, vl_dataloader):
        """Append the average validation PSNR of the generator to the log."""
        # set model to test mode!
        SR_model['net_G'].eval()
        val_psnr_avg = 0.0
        val_psnr_cnt = 0

        with torch.no_grad():
            for val_lr, val_hr, _ in vl_dataloader:
                val_lr, val_hr = utils.tensor_prepare([val_lr, val_hr],
                                                      self.args)
                val_sr = utils.quantize(SR_model['net_G'](val_lr))
                val_psnr_avg += utils.calc_psnr(val_sr, val_hr,
                                                self.args.scale)
                val_psnr_cnt += 1

            val_psnr_avg /= val_psnr_cnt
            val_res_text = 'epoch: %d\tlr: %f\t%s loss: %05.2f\ttrain %s valid %s PSNR avg: %f [%s]\n' % \
                           (epoch, self.lr_G_val, self.args.loss, self.loss_val,
                            self.args.tr_dset_name, self.args.vl_dset_name, float(val_psnr_avg), datetime.now())

            with open(self.f_vl_fname, 'at') as self.f_vl_rec:
                self.f_vl_rec.write(val_res_text)
            print(val_res_text[:-1])

    def _do_test(self, SR_model, te_dataloader):
        """Save LR/HR/SR images and log per-image and average test PSNR.

        Raises:
            ValueError: if ``self.args.PSNR_ver`` is not 1, 2, 3 or 4.
        """
        SR_model['net_G'].eval()
        te_psnr_avg = 0.0
        te_psnr_cnt = 0

        with torch.no_grad():
            for _, (te_lr, te_hr,
                    te_name) in tqdm(enumerate(te_dataloader)):
                self.args.te_name = te_name[0]
                te_lr, te_hr = utils.tensor_prepare([te_lr, te_hr],
                                                    self.args)

                # RRDB reference weights take inputs in [0, 1]. The LR input
                # is scaled down in place, then LR and SR are scaled back.
                if self.args.RRDB_ref:
                    te_lr = te_lr.mul_(1.0 / 255.0)

                te_sr = SR_model['net_G'](te_lr)
                if self.args.RRDB_ref:
                    te_lr = te_lr.mul_(255.0)
                    te_sr = te_sr.mul_(255.0)
                te_sr = utils.quantize(te_sr)

                if self.args.PSNR_ver in (1, 3):
                    # original or div4 PSNR
                    te_psnr = utils.calc_psnr(te_sr, te_hr,
                                              self.args.scale,
                                              self.args.rgb_range)
                elif self.args.PSNR_ver == 2:
                    # patch-based PSNR
                    te_psnr = utils.calc_psnr_pb_forward(
                        self.args, te_sr, te_hr, self.args.scale,
                        self.args.rgb_range)
                elif self.args.PSNR_ver == 4:
                    te_psnr = utils.calc_psnr_dpb_forward(
                        self.args, te_sr, te_hr)
                else:
                    # Previously te_psnr was left unbound (or stale from the
                    # last image) here.
                    raise ValueError('unsupported PSNR_ver: {}'.format(
                        self.args.PSNR_ver))

                base_name = self.args.save_test + '/images/' + te_name[0]
                utils.save_tensor_to_image(self.args, te_lr, base_name + '_LR')
                utils.save_tensor_to_image(self.args, te_hr, base_name + '_HR')
                utils.save_tensor_to_image(self.args, te_sr, base_name + '_SR')

                psnr_txt = '%s\t%f\n' % (te_name[0], te_psnr)
                with open(self.f_te_fname, 'at') as self.f_te_rec:
                    self.f_te_rec.write(psnr_txt)
                print(psnr_txt[:-1])

                te_psnr_avg += te_psnr
                te_psnr_cnt += 1

        te_psnr_avg /= te_psnr_cnt

        print('%d of tests are completed, average PSNR: [%.2f]' %
              (te_psnr_cnt, te_psnr_avg))
Example #14
0
    # Load as RGB and round the size down to a multiple of args.scale.
    image = pil_image.open(args.image_file).convert('RGB')

    image_width = (image.width // args.scale) * args.scale
    image_height = (image.height // args.scale) * args.scale

    hr = image.resize((image_width, image_height), resample=pil_image.BICUBIC)
    lr = hr.resize((hr.width // args.scale, hr.height // args.scale),
                   resample=pil_image.BICUBIC)
    bicubic = lr.resize((lr.width * args.scale, lr.height * args.scale),
                        resample=pil_image.BICUBIC)
    # NOTE(review): str.replace('.', ...) rewrites every '.' in the path,
    # not only the extension dot.
    bicubic.save(
        args.image_file.replace('.', '_bicubic_x{}.'.format(args.scale)))

    # preprocess returns a pair. Its second element (ycbcr) is kept only for
    # the bicubic image and supplies the chroma of the output below.
    lr, _ = preprocess(lr, device)
    hr, _ = preprocess(hr, device)
    _, ycbcr = preprocess(bicubic, device)

    with torch.no_grad():
        preds = model(lr).clamp(0.0, 1.0)

    psnr = calc_psnr(hr, preds)
    print('PSNR: {:.2f}'.format(psnr))

    # Scale back by 255 and drop the two leading size-1 dims. Presumably this
    # leaves the single predicted (Y) channel as H x W -- TODO confirm.
    preds = preds.mul(255.0).cpu().numpy().squeeze(0).squeeze(0)

    # Rebuild a colour image from the predicted channel plus channels 1 and 2
    # of the bicubic image's ycbcr array. Convert it back to RGB and clip to
    # 0..255.
    output = np.array([preds, ycbcr[..., 1], ycbcr[...,
                                                   2]]).transpose([1, 2, 0])
    output = np.clip(convert_ycbcr_to_rgb(output), 0.0, 255.0).astype(np.uint8)
    output = pil_image.fromarray(output)
    output.save(args.image_file.replace('.', '_espcn_x{}.'.format(args.scale)))
Example #15
0
        # First item of the batch, moved to numpy with channels last.
        hr = hr[0].detach().cpu().numpy()
        sr = sr[0].detach().cpu().numpy()
        sr = np.transpose(sr, (1, 2, 0))
        hr = np.transpose(hr, (1, 2, 0))

    # Save both images as PNGs named after data[2].
    # NOTE(review): astype(np.uint8) truncates and wraps out-of-range values.
    # Clip/round first if sr is not already within 0..255.
    sr = sr.astype(np.uint8)
    Image.fromarray(sr).save(
        os.path.join(data_root, args.model, data[2] + '.png'))
    hr = hr.astype(np.uint8)
    Image.fromarray(hr).save(os.path.join(data_root, "hr", data[2] + '.png'))

    # Metrics use the uint8-quantised images rescaled to [0, 1], hence
    # rgb_range=1.
    sr = sr / 255.0
    hr = hr / 255.0

    psnr = utils.calc_psnr(sr, hr, scale=int(args.scale[0]), rgb_range=1.)
    ssim = utils.calc_ssim(sr, hr)
    psnrs.append(psnr)
    ssims.append(ssim)
    print(psnr)
    #print(ssim)
    #print(psnr)

    # plt.subplot(121)
    # plt.imshow(lr.astype(np.uint8))
    #
    # plt.subplot(122)
    # plt.imshow(hr.astype(np.uint8))
    #
    # plt.show()
# NOTE(review): only the mean SSIM is printed. The collected psnrs are not
# averaged in this view.
print(np.mean(np.array(ssims)))
Example #16
0
        # NOTE(review): cur_w/cur_h are not used in this view.
        cur_w, cur_h = image.width * args.scale, image.height * args.scale
        # Bicubic upscale by args.scale; the model then refines channel 0.
        image = image.resize(
            (image.width * args.scale, image.height * args.scale),
            resample=pil_image.BICUBIC)

        image = np.array(image).astype(np.float32)
        ycbcr = convert_rgb_to_ycbcr(image)

        # Channel 0 of ycbcr, divided by 255, as a 1 x 1 x H x W tensor.
        y = ycbcr[..., 0]
        y /= 255.
        y = torch.from_numpy(y).to(device)
        y = y.unsqueeze(0).unsqueeze(0)

        with torch.no_grad():
            preds = model(y).clamp(0.0, 1.0)

        # NOTE(review): this compares the model output with its own
        # (bicubic-upscaled) input y, not with a ground-truth image. It
        # measures how much the model changed the input, not output quality.
        psnr = calc_psnr(y, preds)
        psnr_seq.append(psnr.cpu().item())
        print('{} PSNR: {:.2f}'.format(image_path.split('/')[-1], psnr))

        preds = preds.mul(255.0).cpu().numpy().squeeze(0).squeeze(0)

        # Predicted channel 0 plus the bicubic image's channels 1 and 2,
        # converted back to RGB and clipped to 0..255.
        output = np.array([preds, ycbcr[..., 1],
                           ycbcr[..., 2]]).transpose([1, 2, 0])
        output = np.clip(convert_ycbcr_to_rgb(output), 0.0,
                         255.0).astype(np.uint8)
        output = pil_image.fromarray(output)
        # NOTE(review): plain concatenation requires args.outputs_dir to end
        # with a path separator.
        output.save(args.outputs_dir + image_path.split('/')[-1])

    print(f'Average PSNR: {np.mean(psnr_seq)}')
Example #17
0
    def test(self, epoch=10):
        """Evaluate the multi-scale model on ``self.loader_test``.

        The model returns one SR output per entry of ``self.args.upscale``.
        When the loader yields targets (``hr`` is a list, one image per
        scale), per-image PSNR/SSIM are accumulated and their averages are
        logged. Outputs are saved when ``self.args.save`` is set.

        Unless ``self.args.test`` is set:
        - the model is saved as 'latest';
        - a ``<epoch>_best`` snapshot is saved whenever the accumulated PSNR
          at the last scale in ``upscale`` exceeds ``self.best_psnr``.

        Args:
            epoch: epoch number recorded with the best snapshot.
        """
        self.ckp.write_log('=> Evaluation...')
        timer_test = utils.timer()
        upscale = self.args.upscale
        avg_psnr = {scale: 0.0 for scale in upscale}
        avg_ssim = {scale: 0.0 for scale in upscale}
        # Initialised so the summary below also works for an empty loader
        # (previously a NameError).
        has_target = False

        for iteration, (lr, hr) in enumerate(self.loader_test, 1):

            # Demo data has no targets; the loader then yields a non-list hr.
            has_target = isinstance(hr, list)

            if has_target:
                lr, hr = self.prepare([lr, hr])
            else:
                lr = self.prepare([lr])[0]

            sr = self.model(lr)

            save_list = [*sr, lr]

            if has_target:
                save_list.extend(hr)

                psnr = {}
                ssim = {}
                for i, scale in enumerate(upscale):
                    psnr[scale] = utils.calc_psnr(hr[i], sr[i], int(scale))
                    ssim[scale] = utils.calc_ssim(hr[i], sr[i])
                    avg_psnr[scale] += psnr[scale]
                    avg_ssim[scale] += ssim[scale]

            if self.args.save:
                if has_target:
                    for scale in upscale:
                        self.ckp.write_log(
                            '=> Image{} PSNR_x{}: {:.4f}'.format(
                                iteration, scale, psnr[scale]))
                        self.ckp.write_log(
                            '=> Image{} SSIM_x{}: {:.4f}'.format(
                                iteration, scale, ssim[scale]))
                self.ckp.save_result(iteration, save_list)

        if has_target:
            for scale, value in avg_psnr.items():
                self.ckp.write_log("=> PSNR_x{}: {:.4f}".format(
                    scale, value / len(self.loader_test)))
                self.ckp.write_log("=> SSIM_x{}: {:.4f}".format(
                    scale, avg_ssim[scale] / len(self.loader_test)))

        self.ckp.write_log("=> Total time: {:.1f}s".format(timer_test.toc()))

        if not self.args.test:
            self.ckp.save_model(self.model, 'latest')
            # NOTE(review): this is the PSNR summed over the loader, not the
            # average. It is comparable across epochs only while the loader
            # size stays fixed.
            cur_psnr = avg_psnr[upscale[-1]]
            if self.best_psnr < cur_psnr:
                self.best_psnr = cur_psnr
                self.best_epoch = epoch
                self.ckp.save_model(self.model,
                                    '{}_best'.format(self.best_epoch))
Example #18
0
                t.set_postfix(loss='{:.6f}'.format(epoch_losses.avg))
                t.update(len(inputs))

        # Save this epoch's weights.
        torch.save(
            model.state_dict(),
            os.path.join(args.outputs_dir, 'epoch_{}.pth'.format(epoch)))

        # Evaluate: PSNR of predictions clamped to [0, 1]. Each update passes
        # len(inputs) as the count.
        model.eval()
        epoch_psnr = AverageMeter()

        for data in eval_dataloader:
            inputs, labels = data

            inputs = inputs.to(device)
            labels = labels.to(device)

            with torch.no_grad():
                preds = model(inputs).clamp(0.0, 1.0)

            epoch_psnr.update(calc_psnr(preds, labels), len(inputs))

        print('eval psnr: {:.2f}'.format(epoch_psnr.avg))

        # Track the best epoch and keep a deep copy of its weights.
        if epoch_psnr.avg > best_psnr:
            best_epoch = epoch
            best_psnr = epoch_psnr.avg
            best_weights = copy.deepcopy(model.state_dict())

    # NOTE(review): best_weights is only assigned inside the comparison above.
    # Confirm it is initialised before the loop; otherwise a run where no
    # epoch beats the initial best_psnr fails here.
    print('best epoch: {}, psnr: {:.2f}'.format(best_epoch, best_psnr))
    torch.save(best_weights, os.path.join(args.outputs_dir, 'best.pth'))
Example #19
0
    def test(self):
        """Evaluate the video SR model on ``self.loader_test``.

        Loader modes:
        - With ``args.real``, the loader yields ``(lr, filename)`` and no
          metrics are computed.
        - Otherwise it yields ``(lr, hr, kernels, filename)``. PSNR and SSIM
          of the clamped output against HR frame ``args.n_sequence // 2`` are
          reported to ``self.ckp``.

        Outputs are saved every 30th sample when ``args.save_images`` is set,
        and for every sample when ``args.test_only`` is set. Unless
        ``args.test_only`` is set, a checkpoint is saved with ``self.epoch``.
        """
        self.ckp.write_log('\nEvaluation:')
        self.model.eval()
        self.ckp.start_log(train=False)
        self.ckp.start_log(train=False, key='ssim')
        with torch.no_grad():
            tqdm_test = tqdm(self.loader_test, ncols=80)
            for idx_img, data_pack in enumerate(tqdm_test):
                if self.args.real:
                    lr, filename = data_pack
                else:
                    lr, hr, _kernels, filename = data_pack
                filename = filename[len(filename) // 2]
                # lr: [batch_size, n_seq, 3, patch_size, patch_size]
                if self.args.n_colors == 1 and lr.size()[2] == 3:
                    lr = lr[:, :, 0:1, :, :]
                    if not self.args.real:
                        hr = hr[:, :, 0:1, :, :]

                # [N, n_seq, C, H, W] -> [N * n_seq, C, H, W]. Splitting on the
                # batch dim keeps each sample's frames contiguous and in order.
                lr = torch.cat([
                    torch.squeeze(x.to(self.device), dim=0)
                    for x in torch.split(lr, 1, dim=0)
                ], dim=0)
                if not self.args.real:
                    # Only the centre HR frame of each sample is compared.
                    center = self.args.n_sequence // 2
                    center_hr = torch.cat([
                        x[:, center, :, :, :].to(self.device)
                        for x in torch.split(hr, 1, dim=0)
                    ], dim=0)
                cur_kernel_pca = None

                sr, _, _ = self.model(lr, cur_kernel_pca)
                sr = torch.clamp(sr, min=0.0, max=1.0)
                if not self.args.real:
                    PSNR = utils.calc_psnr(self.args, sr, center_hr)
                    SSIM = utils.calc_ssim(self.args, sr, center_hr)
                    self.ckp.report_log(PSNR, train=False)
                    self.ckp.report_log(SSIM, train=False, key='ssim')

                # Same precedence as before, written out explicitly:
                # test_only saves every image even without save_images.
                if (self.args.save_images
                        and idx_img % 30 == 0) or self.args.test_only:
                    self.ckp.save_images(filename[0], [sr], self.args.scale)

            # NOTE(review): in args.real mode nothing is reported, yet the
            # logs are still closed and psnr_log is read below. Confirm ckp
            # tolerates that.
            self.ckp.end_log(len(self.loader_test), train=False)
            self.ckp.end_log(len(self.loader_test), train=False, key='ssim')
            best = self.ckp.psnr_log.max(0)
            self.ckp.write_log(
                '[{}]\taverage PSNR: {:.3f} , average SSIM: {:.3f} (Best: {:.3f} @epoch {})'
                .format(self.args.data_test, self.ckp.psnr_log[-1],
                        self.ckp.ssim_log[-1], best[0], best[1] + 1))
            if not self.args.test_only:
                self.ckp.save(self,
                              self.epoch,
                              is_best=(best[1] + 1 == self.epoch))