Ejemplo n.º 1
0
class Test:
    """Run CFNet inference over paired over/under-exposed test images.

    Pairs are formed by sorting the file names found in
    ``<args.dir_test>lr_over/`` and ``<args.dir_test>lr_under/`` and matching
    them by position. Each fused result is written to ``args.save_dir`` as
    ``<stem of the over-exposed file><args.ext>``.

    Relies on a module-level ``args`` namespace (``dir_test``, ``model_path``,
    ``model``, ``save_dir``, ``ext``) and on ``CFNet`` defined elsewhere.
    """

    def __init__(self):
        # ToTensor turns the HWC image into a CHW float tensor; Normalize then
        # applies (x - 0.5) / 0.5 per channel.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])
        self.test_dir_pre = args.dir_test
        # Plain concatenation (as before): args.dir_test is expected to end
        # with a path separator.
        self.over_dir = self.test_dir_pre + 'lr_over/'
        self.under_dir = self.test_dir_pre + 'lr_under/'
        # os.listdir order is arbitrary; sort so the i-th over image pairs
        # with the i-th under image.
        self.over_imgs = sorted(os.listdir(self.over_dir))
        self.under_imgs = sorted(os.listdir(self.under_dir))
        assert len(self.over_imgs) == len(self.under_imgs), (
            'lr_over/ and lr_under/ contain different numbers of files')
        self.num_imgs = len(self.over_imgs)

        self.model = CFNet().cuda()
        self.state = torch.load(args.model_path + args.model)
        self.model.load_state_dict(self.state['model'])

        # Wall-clock seconds of each timed forward pass; accumulates across
        # calls to test().
        self.test_time = []

    def _load_image(self, path):
        """Read ``path`` with OpenCV and return it as a batched, normalized tensor.

        Raises:
            OSError: if OpenCV cannot read the file (``cv2.imread`` signals
                failure by returning None rather than raising).
        """
        img = cv2.imread(path)
        if img is None:
            raise OSError('Could not read image: ' + path)
        return torch.unsqueeze(self.transform(img), 0)

    @staticmethod
    def _to_uint8_image(chw):
        """Convert a CxHxW numeric array into an HxWxC uint8 image.

        Values are clipped to [0, 255] before the cast, because a bare
        ``astype(np.uint8)`` on out-of-range floats wraps around instead of
        saturating. In-range values are still truncated toward zero.
        """
        hwc = np.transpose(chw, (1, 2, 0))
        return np.clip(hwc, 0, 255).astype(np.uint8)

    def test(self):
        """Fuse every image pair, save the results and print the mean run time.

        Only the forward pass plus the averaging of the two branch outputs is
        timed; CUDA is synchronized before and after so that asynchronous
        kernels are included in the measurement.

        Raises:
            OSError: if an input image cannot be read or an output image
                cannot be written.
        """
        self.model.eval()
        if args.save_dir:
            # cv2.imwrite does not create missing directories.
            os.makedirs(args.save_dir, exist_ok=True)
        with torch.no_grad():
            for idx in trange(self.num_imgs):
                img1 = self._load_image(self.over_dir + self.over_imgs[idx])
                img2 = self._load_image(self.under_dir + self.under_imgs[idx])

                assert img1.shape == img2.shape
                # os.listdir yields bare file names, so only the extension
                # needs stripping.
                save_name = os.path.splitext(self.over_imgs[idx])[0]

                img1 = img1.cuda()
                img2 = img2.cuda()
                torch.cuda.synchronize()
                start_time = time.time()

                # Average the last output of each branch to form the result.
                sr_over, sr_under = self.model(img1, img2)
                img_fused = 0.5 * sr_over[-1] + 0.5 * sr_under[-1]
                img_fused = img_fused.squeeze(0)

                torch.cuda.synchronize()
                end_time = time.time()
                self.test_time.append(end_time - start_time)

                # NOTE(review): assumes the network emits pixel values on the
                # 0-255 scale (inputs are normalized to [-1, 1]) -- confirm.
                img_fused = self._to_uint8_image(img_fused.cpu().numpy())

                out_path = os.path.join(args.save_dir,
                                        str(save_name) + args.ext)
                # cv2.imwrite reports failure via its return value only.
                if not cv2.imwrite(out_path, img_fused):
                    raise OSError('Could not write image: ' + out_path)

            if self.test_time:
                print('The average testing time is {:.4f} s.'.format(
                    np.mean(self.test_time)))
            else:
                print('No test images were processed.')
Ejemplo n.º 2
0
class Validation(object):
    """Compute the mean PSNR of CFNet fused outputs against ground truth.

    Reads sorted file lists from ``<args.dir_val>gt/``, ``lr_over/`` and
    ``lr_under/`` and pairs them by position, then loads the checkpoint
    ``<args.model_path>latest.pth``. Relies on a module-level ``args``
    namespace and on ``CFNet`` defined elsewhere.
    """

    def __init__(self):
        self.psnr_list = []
        # ToTensor turns the HWC image into a CHW float tensor; Normalize then
        # applies (x - 0.5) / 0.5 per channel.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])
        self.val_dir_pre = args.dir_val
        # Plain concatenation (as before): args.dir_val is expected to end
        # with a path separator.
        self.gt_dir = self.val_dir_pre + 'gt/'
        self.over_dir = self.val_dir_pre + 'lr_over/'
        self.under_dir = self.val_dir_pre + 'lr_under/'
        # os.listdir order is arbitrary; without sorting, the idx-th gt,
        # over and under files need not belong to the same scene.
        self.gt_imgs = sorted(os.listdir(self.gt_dir))
        self.over_imgs = sorted(os.listdir(self.over_dir))
        self.under_imgs = sorted(os.listdir(self.under_dir))
        assert len(self.over_imgs) == len(self.under_imgs), (
            'lr_over/ and lr_under/ contain different numbers of files')
        assert len(self.gt_imgs) == len(self.over_imgs), (
            'gt/ and lr_over/ contain different numbers of files')
        self.num_imgs = len(self.over_imgs)

        self.model = CFNet().cuda()
        self.state = torch.load(args.model_path + 'latest.pth')
        self.model.load_state_dict(self.state['model'])

    @staticmethod
    def _read_image(path):
        """Read ``path`` with OpenCV, raising instead of returning None.

        Raises:
            OSError: if OpenCV cannot read the file.
        """
        img = cv2.imread(path)
        if img is None:
            raise OSError('Could not read image: ' + path)
        return img

    @staticmethod
    def _to_uint8_image(chw):
        """Convert a CxHxW numeric array into an HxWxC uint8 image.

        Values are clipped to [0, 255] before the cast, because a bare
        ``astype(np.uint8)`` on out-of-range floats wraps around instead of
        saturating. In-range values are still truncated toward zero.
        """
        hwc = np.transpose(chw, (1, 2, 0))
        return np.clip(hwc, 0, 255).astype(np.uint8)

    def validation(self):
        """Return the mean PSNR (dB) over all validation triples.

        Raises:
            OSError: if any input or ground-truth image cannot be read.
        """
        ep_psnr_list = []
        self.model.eval()
        with torch.no_grad():
            for idx in trange(self.num_imgs):
                img1 = self._read_image(self.over_dir + self.over_imgs[idx])
                img1 = torch.unsqueeze(self.transform(img1), 0)
                img2 = self._read_image(self.under_dir + self.under_imgs[idx])
                img2 = torch.unsqueeze(self.transform(img2), 0)
                img_gt = self._read_image(self.gt_dir + self.gt_imgs[idx])

                assert img1.shape == img2.shape

                img1 = img1.cuda()
                img2 = img2.cuda()

                # Average the last output of each branch to form the result.
                sr_over, sr_under = self.model(img1, img2)
                img_fused = 0.5 * sr_over[-1] + 0.5 * sr_under[-1]
                img_fused = img_fused.squeeze(0)

                # NOTE(review): assumes the network emits pixel values on the
                # 0-255 scale (inputs are normalized to [-1, 1]) -- confirm.
                img_fused = self._to_uint8_image(img_fused.cpu().numpy())

                psnr_idx = self.calc_psnr(img_fused, img_gt)
                ep_psnr_list.append(psnr_idx)
        return np.mean(ep_psnr_list)

    def calc_psnr(self, img1, img2):
        """Return the PSNR in dB between two images on the 0-255 scale.

        Both images are rescaled to [0, 1] before the MSE is taken, so the
        peak signal value is 1. Identical images (zero MSE) yield
        ``float('inf')`` instead of dividing by zero.
        """
        mse = np.mean((img1 / 255. - img2 / 255.)**2)
        if mse == 0:
            return float('inf')
        pixel_max = 1.
        return 20 * math.log10(pixel_max / math.sqrt(mse))