def test(args):
    """Evaluate the trained U-Net on the Sony test set and save comparison JPEGs.

    For each batch, saves three images into ``<result_dir>/eval``:
    the mean-matched scaled input (``*_scale.jpg``), the network output
    (``*_out.jpg``), and the ground truth (``*_gt.jpg``).

    Args:
        args: Namespace with ``gpu``, ``input_dir``, ``gt_dir``, ``batch_size``,
            ``num_workers``, ``model`` (checkpoint path) and ``result_dir``.
    """
    # device
    device = torch.device("cuda:%d" % args.gpu if torch.cuda.is_available() else "cpu")
    torch.backends.cudnn.benchmark = True

    # data
    testset = SonyTestDataset(args.input_dir, args.gt_dir)
    test_loader = DataLoader(testset,
                             batch_size=args.batch_size,
                             shuffle=False,
                             num_workers=args.num_workers)

    # model — map_location lets a CUDA-saved checkpoint load on a CPU-only host
    model = Unet()
    model.load_state_dict(torch.load(args.model, map_location=device))
    model.to(device)
    model.eval()

    # output directory: create once, outside the loop
    eval_dir = os.path.join(args.result_dir, 'eval')
    os.makedirs(eval_dir, exist_ok=True)

    # testing — no_grad: inference only, skip autograd bookkeeping
    with torch.no_grad():
        for i, databatch in tqdm(enumerate(test_loader), total=len(test_loader)):
            input_full, scale_full, gt_full, test_id, ratio = databatch
            scale_full, gt_full = torch.squeeze(scale_full), torch.squeeze(gt_full)

            # processing
            inputs = input_full.to(device)
            outputs = model(inputs)
            outputs = outputs.cpu().detach()
            outputs = torch.squeeze(outputs)
            outputs = outputs.permute(1, 2, 0)  # CHW -> HWC for image saving

            # scaling and clipping
            outputs, scale_full, gt_full = (outputs.numpy(),
                                            scale_full.numpy(),
                                            gt_full.numpy())
            # scale the low-light image to the same mean as the ground truth
            scale_full = scale_full * np.mean(gt_full) / np.mean(scale_full)
            outputs = np.minimum(np.maximum(outputs, 0), 1)

            # saving
            # NOTE(review): scipy.misc.toimage was removed in SciPy >= 1.2;
            # this requires the older SciPy the project already depends on.
            scipy.misc.toimage(
                scale_full * 255, high=255, low=0, cmin=0, cmax=255).save(
                    os.path.join(
                        eval_dir,
                        '%05d_00_train_%d_scale.jpg' % (test_id[0], ratio[0])))
            scipy.misc.toimage(
                outputs * 255, high=255, low=0, cmin=0, cmax=255).save(
                    os.path.join(
                        eval_dir,
                        '%05d_00_train_%d_out.jpg' % (test_id[0], ratio[0])))
            scipy.misc.toimage(
                gt_full * 255, high=255, low=0, cmin=0, cmax=255).save(
                    os.path.join(
                        eval_dir,
                        '%05d_00_train_%d_gt.jpg' % (test_id[0], ratio[0])))
def test(args):
    """Run the trained U-Net over every ``*.DNG`` raw image in a directory.

    Each raw file is packed, amplified by a fixed ratio, center-cropped to
    1024x1024, pushed through the network, and the result is saved next to
    the input as ``<name>.DNG_out.jpg``.

    Args:
        args: Namespace with ``gpu``, ``imgdir`` (directory of .DNG files)
            and ``model`` (checkpoint path).
    """
    # device
    device = torch.device("cuda:%d" % args.gpu if torch.cuda.is_available() else "cpu")
    torch.backends.cudnn.benchmark = True

    # images path
    fns = glob.glob(path.join(args.imgdir, '*.DNG'))

    # model — map_location lets a CUDA-saved checkpoint load on a CPU-only host
    model = Unet()
    model.load_state_dict(torch.load(args.model, map_location=device))
    model.to(device)
    model.eval()

    # fixed amplification ratio applied to every raw image
    ratio = 200

    # no_grad: inference only, skip autograd bookkeeping
    with torch.no_grad():
        for fn in fns:
            print(fn)
            raw = rawpy.imread(fn)
            # pack, amplify, crop, then shape to NCHW for the network
            packed = np.expand_dims(pack_raw(raw), axis=0) * ratio
            packed = crop_center(packed, 1024, 1024)
            inputs = torch.from_numpy(packed)
            inputs = torch.squeeze(inputs)
            inputs = inputs.permute(2, 0, 1)  # HWC -> CHW
            inputs = torch.unsqueeze(inputs, dim=0)  # add batch dim
            inputs = inputs.to(device)

            outputs = model(inputs)
            outputs = outputs.cpu().detach()
            outputs = torch.squeeze(outputs)
            outputs = outputs.permute(1, 2, 0)  # CHW -> HWC for image saving
            outputs = outputs.numpy()
            outputs = np.minimum(np.maximum(outputs, 0), 1)

            # NOTE(review): scipy.misc.toimage was removed in SciPy >= 1.2;
            # this requires the older SciPy the project already depends on.
            scipy.misc.toimage(outputs * 255, high=255, low=0, cmin=0, cmax=255).save(
                path.join(args.imgdir, path.basename(fn) + '_out.jpg'))