def denois_example(index, netName='dncnn_CNN16', sigma=0.1, num_copy=1,
                   dataDir='/home/yuqi/spinner/dataset/stl10'):
    """Denoise one STL-10 test image and save noisy/clean/denoised JPEGs.

    Parameters
    ----------
    index : int
        Index into the noisy test set (each clean image has ``num_copy``
        noisy copies, so the clean image index is ``index // num_copy``).
    netName : str
        Checkpoint/network identifier passed through to ``get_output``.
    sigma : float
        Noise level used to build the noisy dataset.
    num_copy : int
        Number of noisy copies per clean image.
    dataDir : str
        Root directory of the STL-10 data.

    Returns
    -------
    The PSNR computed from the error image ``img - denoised``
    (presumably ``PSNR`` expects the residual — defined elsewhere in the file).
    """
    testset = dataset.noisy_stl10(sigma, num_copy=num_copy,
                                  dataDir=dataDir, train=False)
    # Both arrays are in the [0, 1] range per the original comments.
    noisy = testset.test_data_noisy[index]
    img = testset.test_data[int(index // num_copy)]

    denoised = get_output(in_img=noisy, netName=netName,
                          sigma=sigma, num_copy=num_copy)
    psnr = PSNR(img - denoised)

    # Debug: report dynamic range of the clean and denoised images.
    print(np.max(img), np.min(img))
    print(np.max(denoised), np.min(denoised))

    imsave('../result/test_noisy_%d.jpg' % index, noisy)
    imsave('../result/test_img_%d.jpg' % index, img)
    imsave('../result/test_denoised_%d.jpg' % index, denoised)
    return psnr
# Data Part dataDir = '/home/yuqi/spinner/dataset/stl10' # dataDir = '../stl10' # Running on Gcloud num_train = 2000 # max 5000 num_test = 5000 batch_size_train = 32 batch_size_test = 20 best_accu = 0 img_transform = transforms.Compose([transforms.ToTensor()]) trainset = dataset.noisy_stl10(sigma, num_train=num_train, num_test=num_test, num_copy=num_copy, dataDir=dataDir, transform=img_transform, train=True) trainloader = DataLoader(trainset, batch_size=batch_size_train, shuffle=True) testset = dataset.noisy_stl10(sigma, num_train=num_train, num_test=num_test, num_copy=num_copy, dataDir=dataDir, transform=img_transform, train=False) testloader = DataLoader(testset, batch_size=batch_size_test, shuffle=True) num_epochs = args.epoch reg_term = args.reg_term
def get_denoised_dataset(dataDir, outDir, sigma, netName, num_copy=1):
    """Denoise the full noisy STL-10 train/test sets with a trained network.

    Loads the checkpoint ``./checkpoints/ckpt_%s_sigma%.2f_copy%d.t7``,
    runs every train and test image through the network (no shuffling, so
    array order matches the dataset order), dumps a sample JPEG every 100
    test images, and saves all arrays (denoised/clean/noisy + labels) into
    ``%s/denoisedSTL10_%s_sigma%.2f_copy%d.npz`` under *dataDir*.

    Parameters
    ----------
    dataDir : str
        STL-10 data root; also where the output .npz is written.
    outDir : str
        Directory whose ``images/`` subfolder receives sample JPEGs.
    sigma : float
        Noise level identifying the dataset/checkpoint.
    netName : str
        Network name; ``netName[6:]`` selects the deepcnn variant.
    num_copy : int
        Number of noisy copies per clean image.
    """
    checkpoint = torch.load('./checkpoints/ckpt_%s_sigma%.2f_copy%d.t7' %
                            (netName, sigma, num_copy))
    net = models.dncnn.deepcnn(netName[6:]).cuda()
    net = torch.nn.DataParallel(net)
    cudnn.benchmark = True
    net.load_state_dict(checkpoint['net'])
    # FIX: switch to inference mode — without this, BatchNorm layers use
    # per-batch statistics during denoising instead of the learned running
    # statistics, corrupting the outputs.
    net.eval()
    device = 'cuda'
    img_transform = transforms.Compose([transforms.ToTensor()])
    batch_size_train = 32
    batch_size_test = 20
    # shuffle=False so denoised arrays line up index-for-index with the
    # dataset's own train_data / test_data arrays.
    trainset = dataset.noisy_stl10(sigma, num_copy=num_copy, dataDir=dataDir,
                                   transform=img_transform, train=True)
    trainloader = DataLoader(trainset, batch_size=batch_size_train,
                             shuffle=False)
    testset = dataset.noisy_stl10(sigma, num_copy=num_copy, dataDir=dataDir,
                                  transform=img_transform, train=False)
    testloader = DataLoader(testset, batch_size=batch_size_test,
                            shuffle=False)

    train_data_denoised = np.zeros(trainset.train_data.shape)
    test_data_denoised = np.zeros(testset.test_data.shape)

    with torch.no_grad():
        for batch_idx, (noisy, img, targets) in enumerate(trainloader):
            noisy = noisy.to(device)
            outputs = net(noisy)  # N*3*96*96
            # .detach() is unnecessary inside no_grad().
            outputs = outputs.cpu().numpy().transpose((0, 2, 3, 1))  # N*96*96*3
            # Last batch may be short; numpy slice assignment handles it.
            train_data_denoised[batch_idx * batch_size_train:
                                (batch_idx + 1) * batch_size_train] = outputs

    with torch.no_grad():
        img_idx = 0
        for batch_idx, (noisy, img, targets) in enumerate(testloader):
            noisy = noisy.to(device)
            outputs = net(noisy)  # N*3*96*96
            outputs_np = outputs.cpu().numpy().transpose(
                (0, 2, 3, 1))  # N*96*96*3
            test_data_denoised[batch_idx * batch_size_test:
                               (batch_idx + 1) * batch_size_test] = outputs_np
            for i in range(len(outputs_np)):
                # Dump every 100th test sample (denoised / clean / noisy).
                if img_idx % 100 == 0:
                    out_img = np.clip(outputs_np[i], 0, 1)
                    imsave(
                        '%s/images/%s_sigma%.2f_test_denoised_%d.jpg' %
                        (outDir, netName, sigma, img_idx), out_img)
                    img_clean = img.cpu().numpy().transpose((0, 2, 3, 1))[i]
                    imsave(
                        '%s/images/%s_sigma%.2f_test_img_%d.jpg' %
                        (outDir, netName, sigma, img_idx), img_clean)
                    # FIX: do not rebind the loop variable `noisy` (a CUDA
                    # tensor) to a numpy slice — the original did, which
                    # would crash if the dump fired twice in one batch.
                    noisy_np = noisy.cpu().numpy().transpose((0, 2, 3, 1))[i]
                    imsave(
                        '%s/images/%s_sigma%.2f_test_noisy_%d.jpg' %
                        (outDir, netName, sigma, img_idx), noisy_np)
                img_idx += 1

    fileName = '%s/denoisedSTL10_%s_sigma%.2f_copy%d.npz' % (
        dataDir, netName, sigma, num_copy)
    np.savez(fileName,
             train_data_denoised=train_data_denoised,
             train_labels=trainset.train_labels,
             train_data=trainset.train_data,
             train_data_noisy=trainset.train_data_noisy,
             test_data_denoised=test_data_denoised,
             test_labels=testset.test_labels,
             test_data=testset.test_data,
             test_data_noisy=testset.test_data_noisy,
             num_copy=num_copy, sigma=sigma, netName=netName)