예제 #1
0
    def get_figure(LR, HR, pred, pred_teacher=None):
        """Build a 1x4 diagnostic figure for the first sample of a batch.

        Panels, left to right: the predicted residual (``pred['hr']`` minus the
        bicubic-upscaled LR), the channel-summed map stored under the last key
        of ``pred`` containing ``'_var'``, the predicted HR image, and
        ``|HR - pred_hr|``. Each title reports the mean absolute value.

        ``scale``, ``rgb_range``, ``float2uint8`` and ``quantize`` come from the
        enclosing scope.

        Args:
            LR: low-resolution input batch; it is interpolated, then index 0
                is used.
            HR: ground-truth high-resolution batch; index 0 is used.
            pred: prediction dict with an ``'hr'`` entry and at least one key
                containing ``'_var'``.
            pred_teacher: unused; kept for signature compatibility.

        Returns:
            The matplotlib figure.

        Raises:
            KeyError: if no key of ``pred`` contains ``'_var'``.
        """
        upscaled_lr = nn.functional.interpolate(LR, scale_factor=scale, mode='bicubic')[0]
        HR = HR[0]
        pred_hr = pred['hr'][0]
        pred_residual_hr_mu = pred_hr - upscaled_lr

        # The last matching key (dict iteration order) is visualised, as before.
        # Without this check a dict lacking any '_var' key raised an opaque
        # UnboundLocalError on `sigma_key`.
        var_keys = [key for key in pred if '_var' in key]
        if not var_keys:
            raise KeyError(
                "pred has no '_var' entry to visualise; keys: %s" % list(pred))
        sigma_key = var_keys[-1]
        pred_residual_hr_sigma = pred[sigma_key][0].sum(0, keepdim=True)
        gt_diff = torch.abs(HR - pred_hr)

        fig, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4, figsize=(16,4))
        cmap = 'gray'
        # NOTE(review): float2uint8 is called here without an rgb_range
        # argument — confirm its default is the intended one.
        ax1.imshow(float2uint8(quantize(pred_residual_hr_mu, rgb_range)), cmap=cmap)
        ax1.set_title('pred_residual_hr_mu, mean_val : %.4f'%torch.abs(pred_residual_hr_mu).mean())
        # The sigma map is passed with normalize=True instead of being quantised.
        ax2.imshow(float2uint8(pred_residual_hr_sigma, normalize=True), cmap=cmap)
        ax2.set_title('pred_residual_hr_sigma, mean_val : %.4f'%torch.abs(pred_residual_hr_sigma).mean())
        ax3.imshow(float2uint8(quantize(pred_hr, rgb_range)), cmap=cmap)
        ax3.set_title('pred_hr, mean_val : %.4f'%torch.abs(pred_hr).mean())
        ax4.imshow(float2uint8(quantize(gt_diff, rgb_range)), cmap=cmap)
        ax4.set_title('GT - pred_hr, mean_val : %.4f'%gt_diff.mean())

        return fig
예제 #2
0
def valid():
    """Evaluate the global ``model`` on ``testing_data_loader``; print PSNR/SSIM.

    Relies on module-level names: ``model``, ``testing_data_loader``,
    ``args`` (``cuda``, ``scale``, ``isY``), ``device``, ``utils``, ``sc`` and
    ``torch``. Only the first image of each batch is scored, after shaving
    ``args.scale`` border pixels. When ``args.isY is True``, channel 0 of
    ``sc.rgb2ycbcr`` (presumably Y) is scored instead of the full image.
    Prints one line per batch and then the averages; returns ``None``.
    """
    model.eval()
    num_batches = len(testing_data_loader)
    if num_batches == 0:
        # Nothing to average; previously the summary print divided by zero.
        print("===> Valid. no batches to evaluate")
        return
    crop_size = args.scale  # border width shaved before scoring
    avg_psnr, avg_ssim = 0, 0
    for i, batch in enumerate(testing_data_loader):
        lr_tensor, hr_tensor = batch[0], batch[1]
        if args.cuda:
            lr_tensor = lr_tensor.to(device)
            hr_tensor = hr_tensor.to(device)

        with torch.no_grad():
            pre = model(lr_tensor)

        sr_img = utils.tensor2np(pre.detach()[0])
        gt_img = utils.tensor2np(hr_tensor.detach()[0])
        cropped_sr_img = utils.shave(sr_img, crop_size)
        cropped_gt_img = utils.shave(gt_img, crop_size)
        # Identity check kept deliberately: only a real boolean True selects Y.
        if args.isY is True:
            im_label = utils.quantize(sc.rgb2ycbcr(cropped_gt_img)[:, :, 0])
            im_pre = utils.quantize(sc.rgb2ycbcr(cropped_sr_img)[:, :, 0])
        else:
            im_label = cropped_gt_img
            im_pre = cropped_sr_img

        psnr = utils.compute_psnr(im_pre, im_label)
        ssim = utils.compute_ssim(im_pre, im_label)

        avg_psnr += psnr
        avg_ssim += ssim
        print(
            f" Valid {i}/{num_batches} with PSNR = {psnr} and SSIM = {ssim}"
        )
    print("===> Valid. psnr: {:.4f}, ssim: {:.4f}".format(
        avg_psnr / num_batches,
        avg_ssim / num_batches))
예제 #3
0
    def get_figure(LR, HR, pred_student, pred_teacher):
        """Plot student-vs-teacher diagnostics for the first sample of a batch.

        Panels, left to right: the student's residual over bicubic upscaling,
        the student's HR prediction, ``|teacher residual - student residual|``
        and ``|HR - student HR|``. Each title reports the mean absolute value.
        ``scale``, ``rgb_range``, ``float2uint8`` and ``quantize`` come from the
        enclosing scope. Returns the matplotlib figure.
        """
        bicubic_hr = nn.functional.interpolate(LR, scale_factor=scale, mode='bicubic')[0]
        target_hr = HR[0]

        student_hr = pred_student['hr'][0]
        teacher_hr = pred_teacher['hr'][0]
        student_residual = student_hr - bicubic_hr
        teacher_residual = teacher_hr - bicubic_hr

        # (image, title format) for each subplot, in display order.
        panels = [
            (student_residual, 'pred_s_residual_hr, mean_val : %.4f'),
            (student_hr, 'pred_s_hr, mean_val : %.4f'),
            (torch.abs(teacher_residual - student_residual),
             'residual_diff, mean_val : %.4f'),
            (torch.abs(target_hr - student_hr), 'GT - pred_hr, mean_val : %.4f'),
        ]

        fig, axes = plt.subplots(1, 4, figsize=(16,4))
        for ax, (image, title_fmt) in zip(axes, panels):
            ax.imshow(float2uint8(quantize(image, rgb_range), rgb_range), cmap='gray')
            ax.set_title(title_fmt % torch.abs(image).mean())

        return fig
예제 #4
0
def evaluate_single_epoch(config, student_model, teacher_model, dataloader,
                          criterion, epoch, writer, visualizer, postfix_dict,
                          eval_type):
    """Evaluate the student for one epoch and log the average loss and PSNR.

    Both models are put in eval mode and run without gradients. Each LR/HR
    batch is sliced to index 0 of dim 1 (presumably the first channel). The
    student's ``'hr'`` prediction is scored with ``criterion['val']`` and,
    after quantisation, with ``get_psnr`` (benchmark mode when
    ``eval_type == 'test'``). When ``writer`` is given and
    ``eval_type == 'test'``, the teacher is also run and a per-batch figure
    from ``visualizer`` is added to ``writer``.

    The averages are written to ``writer`` (if any) as ``<eval_type>/loss`` and
    ``<eval_type>/psnr`` scalars, and stored under the same keys in
    ``postfix_dict``, which is mutated in place.

    Returns:
        The average PSNR over all batches.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    teacher_model.eval()
    student_model.eval()
    with torch.no_grad():
        batch_size = config.eval.batch_size
        total_size = len(dataloader.dataset)
        total_step = math.ceil(total_size / batch_size)

        tbar = tqdm.tqdm(enumerate(dataloader), total=total_step)

        # Figures are only produced for the test split with a writer attached.
        visualize = writer is not None and eval_type == 'test'
        total_psnr = 0
        total_loss = 0
        num_batches = 0
        for i, (LR_img, HR_img, filepath) in tbar:
            num_batches += 1
            HR_img = HR_img[:, :1].to(device)
            LR_img = LR_img[:, :1].to(device)

            student_pred_dict = student_model.forward(LR=LR_img)
            pred_hr = student_pred_dict['hr']
            # The loss uses the raw prediction; PSNR uses the quantised one.
            total_loss += criterion['val'](pred_hr, HR_img).item()

            pred_hr = quantize(pred_hr, config.data.rgb_range)
            total_psnr += get_psnr(pred_hr,
                                   HR_img,
                                   config.data.scale,
                                   config.data.rgb_range,
                                   benchmark=eval_type == 'test')

            f_epoch = epoch + i / total_step
            desc = '{:5s}'.format(eval_type)
            desc += ', {:06d}/{:06d}, {:.2f} epoch'.format(
                i, total_step, f_epoch)
            tbar.set_description(desc)
            tbar.set_postfix(**postfix_dict)

            if visualize:
                # The teacher output is only consumed by the visualiser, so its
                # forward pass is skipped when no figure is produced.
                teacher_pred_dict = teacher_model.forward(LR=LR_img, HR=HR_img)
                fig = visualizer(LR_img, HR_img, student_pred_dict,
                                 teacher_pred_dict)
                writer.add_figure('{}/{:04d}'.format(eval_type, i),
                                  fig,
                                  global_step=epoch)

        if num_batches == 0:
            # Previously an UnboundLocalError on `i`; make the cause explicit.
            raise ValueError(
                'evaluate_single_epoch: no batches for eval_type %r' % eval_type)

        log_dict = {
            'loss': total_loss / num_batches,
            'psnr': total_psnr / num_batches,
        }
        for key, value in log_dict.items():
            if writer is not None:
                writer.add_scalar('{}/{}'.format(eval_type, key), value, epoch)
            postfix_dict['{}/{}'.format(eval_type, key)] = value

        return log_dict['psnr']
예제 #5
0
    def get_figure_basic_fn(LR, HR, pred):
        """Plot prediction diagnostics for the first sample of a batch.

        Panels, left to right: the predicted residual over bicubic upscaling,
        the predicted HR image, the ground-truth HR image and ``|HR - pred|``.
        Each title reports the mean absolute value. ``scale``, ``rgb_range``,
        ``float2uint8`` and ``quantize`` come from the enclosing scope.
        Returns the matplotlib figure.
        """
        bicubic_hr = nn.functional.interpolate(LR, scale_factor=scale, mode='bicubic')[0]
        target_hr = HR[0]
        predicted_hr = pred['hr'][0]

        # (image, title format) for each subplot, in display order.
        panels = [
            (predicted_hr - bicubic_hr, 'pred_residual_hr, mean_val : %.4f'),
            (predicted_hr, 'pred_hr, mean_val : %.4f'),
            (target_hr, 'ground_truth, mean_val : %.4f'),
            (torch.abs(target_hr - predicted_hr), 'GT - pred_hr, mean_val : %.4f'),
        ]

        fig, axes = plt.subplots(1, 4, figsize=(16,4))
        for ax, (image, title_fmt) in zip(axes, panels):
            ax.imshow(float2uint8(quantize(image, rgb_range), rgb_range), cmap='gray')
            ax.set_title(title_fmt % torch.abs(image).mean())

        return fig
    def get_figure(LR, HR, pred_student, pred_teacher):
        """Plot student-vs-teacher residual diagnostics for the first sample.

        Residuals are read directly from the ``'residual_hr'`` entries of both
        prediction dicts. Panels, left to right: the student residual, the
        student HR prediction, ``|teacher residual - student residual|`` and
        ``|HR - student HR|``. Each title reports the mean absolute value.
        ``LR`` is not used; it is kept for signature compatibility.
        ``rgb_range``, ``float2uint8`` and ``quantize`` come from the enclosing
        scope. Returns the matplotlib figure.
        """
        HR = HR[0]
        pred_teacher_residual_hr = pred_teacher['residual_hr'][0]
        pred_student_residual_hr = pred_student['residual_hr'][0]
        pred_student_hr = pred_student['hr'][0]
        residual_diff = torch.abs(pred_teacher_residual_hr - pred_student_residual_hr)
        gt_diff = torch.abs(HR - pred_student_hr)

        fig, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4, figsize=(16,4))
        cmap = 'gray'
        ax1.imshow(float2uint8(quantize(pred_student_residual_hr, rgb_range), rgb_range), cmap=cmap)
        ax1.set_title('pred_s_residual_hr, mean_val : %.4f'% torch.abs(pred_student_residual_hr).mean())
        ax2.imshow(float2uint8(quantize(pred_student_hr, rgb_range), rgb_range), cmap=cmap)
        ax2.set_title('pred_s_hr, mean_val : %.4f'%torch.abs(pred_student_hr).mean())
        # residual_diff and gt_diff are already non-negative.
        ax3.imshow(float2uint8(quantize(residual_diff, rgb_range), rgb_range), cmap=cmap)
        ax3.set_title('residual_diff, mean_val : %.4f'%residual_diff.mean())
        ax4.imshow(float2uint8(quantize(gt_diff, rgb_range), rgb_range), cmap=cmap)
        ax4.set_title('GT - pred_hr, mean_val : %.4f'%gt_diff.mean())

        return fig
예제 #7
0
    def valid(self):
        """Validate during the search process and log PSNR/SSIM for this epoch.

        Appends a zero row (one column per entry of ``self.scale``) to
        ``self.ckp.log``. It then runs the model over ``self.loader_valid`` in
        eval mode without gradients, and stores the average PSNR in the column
        given by the ``idx_scale`` of the *last* batch. It also writes a summary
        line, ``valid_PSNR``/``valid_SSIM`` visuals and the total time.
        """
        epoch = self.scheduler.last_epoch + 1
        self.ckp.write_log('\n\nEvaluation during search process:')
        self.ckp.add_log(torch.zeros(1, len(self.scale)))
        self.model.eval()

        timer_valid = utils.timer()
        num_batches = len(self.loader_valid)
        with torch.no_grad():
            eval_psnr = 0
            eval_ssim = 0
            for _input, _target, _, idx_scale in self.loader_valid:
                _input, _target = self.prepare(_input.detach(), _target.detach())

                # tic()/hold() bracket only the model call.
                timer_valid.tic()
                logits = self.model(_input)
                timer_valid.hold()
                logits = utils.quantize(logits, self.args.rgb_range)

                eval_psnr += utils.calc_psnr(
                    logits, _target, self.scale[idx_scale], self.args.rgb_range,
                    benchmark=False
                )
                eval_ssim += utils.calc_batch_ssim(
                    logits, _target, self.scale[idx_scale],
                    benchmark=False
                )

            if num_batches == 0:
                # Nothing was evaluated. Previously `idx_scale` was unbound here
                # and the averages divided by zero.
                self.ckp.write_log('No validation batches; metrics skipped.')
            else:
                avg_ssim = eval_ssim / num_batches
                self.ckp.log[-1, idx_scale] = eval_psnr / num_batches

                best = self.ckp.log.max(0)
                self.ckp.write_log(
                    '[{} x{}]\tPSNR: {:.3f}\tSSIM: {:.4f}\t(best: {:.3f} @epoch {})'.format(
                        self.args.data_valid,
                        self.scale[idx_scale],
                        self.ckp.log[-1, idx_scale],
                        avg_ssim,
                        best[0][idx_scale],
                        best[1][idx_scale] + 1
                    )
                )
                self.ckp.visual("valid_PSNR", self.ckp.log[-1, idx_scale], epoch)
                self.ckp.visual("valid_SSIM", avg_ssim, epoch)

        self.ckp.write_log(
            'Total time: {:.2f}s\n'.format(timer_valid.toc()), refresh=True
        )
예제 #8
0
파일: evaluate.py 프로젝트: zhwzhong/PISR
def evaluate_single_epoch(config, student_model, dataloader_dict, eval_type):
    """Return the student's average PSNR on each named dataloader.

    For every ``name -> dataloader`` pair, the student runs in eval mode
    without gradients. Its ``'hr'`` prediction is quantised and scored with
    ``get_psnr`` (benchmark mode when ``eval_type == 'test'``). The mean over
    batches is printed and stored under ``name`` in the returned dict. An
    empty dataloader yields 0.0.
    """
    student_model.eval()
    use_benchmark = eval_type == 'test'
    psnr_by_name = {}
    with torch.no_grad():
        for name, loader in dataloader_dict.items():
            print('evaluate %s' % (name))
            n_steps = math.ceil(len(loader.dataset) / config.eval.batch_size)
            progress = tqdm.tqdm(enumerate(loader), total=n_steps)

            psnr_sum = 0
            last_step = 0
            for step, (lr_batch, hr_batch, _) in progress:
                hr_batch = hr_batch.to(device)
                lr_batch = lr_batch.to(device)

                sr_batch = quantize(student_model.forward(LR=lr_batch)['hr'],
                                    config.data.rgb_range)
                psnr_sum += get_psnr(sr_batch, hr_batch, config.data.scale,
                                     config.data.rgb_range,
                                     benchmark=use_benchmark)

                progress.set_description(
                    '{:5s}'.format(eval_type)
                    + ', {:06d}/{:06d}, {:.2f} epoch'.format(step, n_steps,
                                                              step / n_steps))
                last_step = step

            mean_psnr = psnr_sum / (last_step + 1)
            psnr_by_name[name] = mean_psnr
            print('%s : %.3f' % (name, mean_psnr))

    return psnr_by_name
예제 #9
0
        # NOTE(review): fragment — the enclosing loop/function starts outside
        # this view. `i`, `model`, `im_input`, `im_gt`, `imname`, `start`,
        # `end`, `time_list`, `psnr_list`, `ssim_list` and `opt` come from there.
        model = model.to(device)
        im_input = im_input.to(device)

    # Time a single forward pass. `start`/`end` presumably are CUDA events
    # (elapsed_time() is read after torch.cuda.synchronize()) — confirm.
    with torch.no_grad():
        start.record()
        out = model(im_input)
        end.record()
        torch.cuda.synchronize()
        time_list[i] = start.elapsed_time(end)  # milliseconds

    # Score only the first image of the batch, after shaving
    # `opt.upscale_factor` border pixels from both prediction and ground truth.
    out_img = utils.tensor2np(out.detach()[0])
    crop_size = opt.upscale_factor
    cropped_sr_img = utils.shave(out_img, crop_size)
    cropped_gt_img = utils.shave(im_gt, crop_size)
    # Only when opt.is_y is the boolean True: score channel 0 of rgb2ycbcr
    # (presumably Y) instead of the full image.
    if opt.is_y is True:
        im_label = utils.quantize(sc.rgb2ycbcr(cropped_gt_img)[:, :, 0])
        im_pre = utils.quantize(sc.rgb2ycbcr(cropped_sr_img)[:, :, 0])
    else:
        im_label = cropped_gt_img
        im_pre = cropped_sr_img
    psnr_list[i] = utils.compute_psnr(im_pre, im_label)
    ssim_list[i] = utils.compute_ssim(im_pre, im_label)

    # Despite its name, `output_folder` is an output *file* path:
    # <opt.output_folder>/<basename up to its first '.'>x<upscale_factor>.png
    output_folder = os.path.join(
        opt.output_folder,
        imname.split('/')[-1].split('.')[0] + 'x' + str(opt.upscale_factor) +
        '.png')

    # Create the output directory if it does not exist yet.
    if not os.path.exists(opt.output_folder):
        os.makedirs(opt.output_folder)
예제 #10
0
    def test(self):
        """Evaluate every test loader at every scale and log PSNR/SSIM.

        Appends a zero row of shape ``(1, len(loader_test), len(scale))`` to
        ``self.ckp.log`` and fills it with the average PSNR per
        (dataset, scale). For each pair, a log line reports PSNR, SSIM, the best
        PSNR so far and its epoch. With exactly one dataset and one scale,
        ``valid_PSNR``/``valid_SSIM`` visuals are also written.

        Result images are saved in the background when ``args.save_results`` is
        set. Unless ``args.test_only`` is set, a checkpoint is saved. It is
        flagged as best when the best PSNR of the first dataset/scale was
        reached at this epoch.
        """
        epoch = self.scheduler.last_epoch + 1
        self.ckp.write_log('\nEvaluation: ')
        self.ckp.add_log(torch.zeros(1, len(self.loader_test),
                                     len(self.scale)))
        self.model.eval()

        timer_test = utils.timer()
        if self.args.save_results:
            self.ckp.begin_background()
        # Updated after every (dataset, scale) pair. It stays None when there
        # are no test loaders or no scales; previously the checkpoint save
        # below then raised NameError on an unbound `best`.
        best = None
        with torch.no_grad():
            for idx_data, d in enumerate(self.loader_test):
                for idx_scale, scale in enumerate(self.scale):
                    d.dataset.set_scale(idx_scale)
                    eval_acc = 0
                    eval_acc_ssim = 0
                    for _input, _target, filename, _ in tqdm(d, ncols=80):
                        # Only the first filename of the batch is used.
                        filename = filename[0]

                        _input, _target = self.prepare(_input, _target)

                        # tic()/hold() bracket only the model call.
                        timer_test.tic()
                        logits = self.model(_input, idx_scale)
                        timer_test.hold()
                        logits = utils.quantize(logits, self.args.rgb_range)

                        save_list = [logits]
                        eval_acc += utils.calc_psnr(
                            logits,
                            _target,
                            self.scale[idx_scale],
                            self.args.rgb_range,
                            benchmark=d.dataset.benchmark)
                        eval_acc_ssim += utils.calc_ssim(
                            logits,
                            _target,
                            self.scale[idx_scale],
                            benchmark=d.dataset.benchmark)
                        save_list.extend([_input, _target])

                        if self.args.save_results:
                            self.ckp.save_results(filename, save_list,
                                                  self.scale[idx_scale])

                    self.ckp.log[-1, idx_data, idx_scale] = eval_acc / len(d)
                    avg_ssim = eval_acc_ssim / len(d)
                    best = self.ckp.log.max(0)

                    self.ckp.write_log(
                        '\n[{} x{}]\tPSNR: {:.3f}\tSSIM: {:.4f}(Best: {:.3f} @epoch {})'
                        .format(d.dataset.name, self.scale[idx_scale],
                                self.ckp.log[-1, idx_data, idx_scale],
                                avg_ssim,
                                best[0][idx_data, idx_scale],
                                best[1][idx_data, idx_scale] + 1))
                    # Visuals are only written for a single dataset and scale.
                    if len(self.scale) == 1 and len(self.loader_test) == 1:
                        self.ckp.visual("valid_PSNR",
                                        self.ckp.log[-1, idx_data,
                                                     idx_scale], epoch)
                        self.ckp.visual("valid_SSIM", avg_ssim, epoch)

        self.ckp.write_log('Forward: {:.2f}s\n'.format(timer_test.toc()))
        self.ckp.write_log('Saving...')

        if self.args.save_results:
            self.ckp.end_background()

        if not self.args.test_only:
            # best[1] is the row index of the max along dim 0; +1 is compared
            # with the current epoch.
            is_best = best is not None and best[1][0, 0] + 1 == epoch
            self.ckp.save(self, epoch, is_best=is_best)

        self.ckp.write_log('Total: {:.2f}s\n'.format(timer_test.toc()),
                           refresh=True)