class Test:
    """Run CFNet inference over paired over-/under-exposed test images.

    Image pairs are read from ``args.dir_test + 'lr_over/'`` and
    ``args.dir_test + 'lr_under/'`` and paired by sorted filename order. The
    last element of each of the model's two outputs is averaged 50/50, and
    the fused image is written to ``args.save_dir``. The mean forward-pass
    time is reported at the end.
    """

    def __init__(self):
        # ToTensor scales uint8 HWC pixels to a float CHW tensor in [0, 1];
        # Normalize(mean=0.5, std=0.5) then maps each channel to [-1, 1].
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])
        # Used as a string prefix, so it is expected to end with a separator.
        self.test_dir_pre = args.dir_test
        # Pairs are matched by index, so both listings must be sorted the
        # same way (os.listdir order is arbitrary).
        self.over_imgs = sorted(os.listdir(self.test_dir_pre + 'lr_over/'))
        self.under_imgs = sorted(os.listdir(self.test_dir_pre + 'lr_under/'))
        assert len(self.over_imgs) == len(self.under_imgs)
        self.num_imgs = len(self.over_imgs)
        self.model = CFNet().cuda()
        self.state = torch.load(args.model_path + args.model)
        self.model.load_state_dict(self.state['model'])
        # Per-image forward-pass wall-clock times in seconds.
        self.test_time = []

    def _load_tensor(self, path):
        """Read the image at *path* with cv2 and return a batched tensor.

        Returns the transformed image with a leading batch dimension of 1.

        Raises:
            FileNotFoundError: if cv2 cannot read the file.
        """
        img = cv2.imread(path)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError('Could not read image: {}'.format(path))
        return torch.unsqueeze(self.transform(img), 0)

    @staticmethod
    def _to_uint8_image(img):
        """Convert a (C, H, W) tensor to an (H, W, C) uint8 array for cv2.

        Values are clipped to [0, 255] before the cast, because NumPy's
        float->uint8 conversion does not saturate and out-of-range values
        would otherwise be corrupted. Fractional values are truncated.

        NOTE(review): the inputs are normalized to [-1, 1] above, but no
        de-normalization is applied here. This assumes CFNet's output is
        already on the [0, 255] pixel scale; confirm against the model.
        """
        arr = np.transpose(img.cpu().numpy(), (1, 2, 0))
        return np.clip(arr, 0, 255).astype(np.uint8)

    def test(self):
        """Fuse every test pair, save the results, and print mean timing."""
        self.model.eval()
        over_dir = self.test_dir_pre + 'lr_over/'
        under_dir = self.test_dir_pre + 'lr_under/'
        with torch.no_grad():
            for idx in trange(self.num_imgs):
                img1 = self._load_tensor(over_dir + self.over_imgs[idx])
                img2 = self._load_tensor(under_dir + self.under_imgs[idx])
                assert img1.shape == img2.shape
                # Output file is named after the over-exposed input's stem.
                save_name = os.path.splitext(self.over_imgs[idx])[0]
                img1 = img1.cuda()
                img2 = img2.cuda()
                # Synchronize before and after the forward pass so the timing
                # covers the queued CUDA work, not just the kernel launches.
                torch.cuda.synchronize()
                start_time = time.time()
                sr_over, sr_under = self.model(img1, img2)
                # The outputs are indexable; only their last element is fused.
                img_fused = 0.5 * sr_over[-1] + 0.5 * sr_under[-1]
                img_fused = img_fused.squeeze(0)
                torch.cuda.synchronize()
                end_time = time.time()
                self.test_time.append(end_time - start_time)
                cv2.imwrite(
                    os.path.join(args.save_dir, save_name + args.ext),
                    self._to_uint8_image(img_fused))
        if not self.test_time:
            print('No test images were processed.')
            return
        print('The average testing time is {:.4f} s.'.format(
            np.mean(self.test_time)))
class Validation(object):
    """Compute the mean PSNR of CFNet's fused output against ground truth.

    The checkpoint ``args.model_path + 'latest.pth'`` is loaded. Image
    triplets are then evaluated from ``args.dir_val + 'lr_over/'``,
    ``'lr_under/'`` and ``'gt/'``, paired by sorted filename order.
    """

    def __init__(self):
        self.psnr_list = []
        # ToTensor scales uint8 HWC pixels to a float CHW tensor in [0, 1];
        # Normalize(mean=0.5, std=0.5) then maps each channel to [-1, 1].
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])
        # Used as a string prefix, so it is expected to end with a separator.
        self.val_dir_pre = args.dir_val
        # os.listdir order is arbitrary. Sort all three listings so index i
        # refers to the same scene in every folder.
        self.gt_imgs = sorted(os.listdir(self.val_dir_pre + 'gt/'))
        self.over_imgs = sorted(os.listdir(self.val_dir_pre + 'lr_over/'))
        self.under_imgs = sorted(os.listdir(self.val_dir_pre + 'lr_under/'))
        assert len(self.over_imgs) == len(self.under_imgs)
        assert len(self.gt_imgs) == len(self.over_imgs)
        self.num_imgs = len(self.over_imgs)
        self.model = CFNet().cuda()
        self.state = torch.load(args.model_path + 'latest.pth')
        self.model.load_state_dict(self.state['model'])

    @staticmethod
    def _read_image(path):
        """Return the cv2 image at *path*; raise FileNotFoundError on failure.

        cv2.imread returns None instead of raising when it cannot read a file.
        """
        img = cv2.imread(path)
        if img is None:
            raise FileNotFoundError('Could not read image: {}'.format(path))
        return img

    def _load_tensor(self, path):
        """Read *path* and return the transformed image with a batch dim."""
        return torch.unsqueeze(self.transform(self._read_image(path)), 0)

    @staticmethod
    def _to_uint8_image(img):
        """Convert a (C, H, W) tensor to an (H, W, C) uint8 array.

        Values are clipped to [0, 255] before the cast, because NumPy's
        float->uint8 conversion does not saturate. Fractional values are
        truncated.

        NOTE(review): no de-normalization is applied here. This assumes
        CFNet's output is already on the [0, 255] pixel scale; confirm.
        """
        arr = np.transpose(img.cpu().numpy(), (1, 2, 0))
        return np.clip(arr, 0, 255).astype(np.uint8)

    def validation(self):
        """Return the mean PSNR in dB of the fused outputs over the val set.

        Returns NaN if the validation set is empty.
        """
        ep_psnr_list = []
        self.model.eval()
        over_dir = self.val_dir_pre + 'lr_over/'
        under_dir = self.val_dir_pre + 'lr_under/'
        gt_dir = self.val_dir_pre + 'gt/'
        with torch.no_grad():
            for idx in trange(self.num_imgs):
                img1 = self._load_tensor(over_dir + self.over_imgs[idx])
                img2 = self._load_tensor(under_dir + self.under_imgs[idx])
                img_gt = self._read_image(gt_dir + self.gt_imgs[idx])
                assert img1.shape == img2.shape
                sr_over, sr_under = self.model(img1.cuda(), img2.cuda())
                # The outputs are indexable; only their last element is fused.
                img_fused = 0.5 * sr_over[-1] + 0.5 * sr_under[-1]
                img_fused = self._to_uint8_image(img_fused.squeeze(0))
                ep_psnr_list.append(self.calc_psnr(img_fused, img_gt))
        if not ep_psnr_list:
            return float('nan')
        return np.mean(ep_psnr_list)

    def calc_psnr(self, img1, img2):
        """Return the PSNR in dB between two images with values in [0, 255].

        Both images are scaled to [0, 1] before the MSE is computed, so the
        peak value is 1. Identical images give an infinite PSNR instead of
        raising ZeroDivisionError.
        """
        mse = np.mean((img1 / 255. - img2 / 255.)**2)
        if mse == 0:
            return float('inf')
        pixel_max = 1.
        return 20 * math.log10(pixel_max / math.sqrt(mse))