Example #1
0
def denois_example(index,
                   netName='dncnn_CNN16',
                   sigma=0.1,
                   num_copy=1,
                   dataDir='/home/yuqi/spinner/dataset/stl10',
                   resultDir='../result'):
    """Denoise one noisy STL-10 test image and save it next to its references.

    Args:
        index: index into the noisy test set. With ``num_copy`` noisy copies
            per clean image, the matching clean image is
            ``test_data[index // num_copy]``.
        netName: name of the trained denoiser, forwarded to ``get_output``.
        sigma: noise level, used both to build the dataset and to select the
            denoiser.
        num_copy: number of noisy copies generated per clean image.
        dataDir: directory holding the STL-10 data.
        resultDir: directory the three JPEGs are written to (defaults to the
            previously hard-coded ``../result``; it must already exist).

    Returns:
        Whatever the project's ``PSNR`` helper returns for the residual
        ``img - denoised`` (computed on the unclipped network output).
    """
    testset = dataset.noisy_stl10(sigma,
                                  num_copy=num_copy,
                                  dataDir=dataDir,
                                  train=False)
    noisy = testset.test_data_noisy[index]  # range [0,1]
    # Several noisy copies share one clean image; int() also normalises
    # numpy integer indices.
    img = testset.test_data[int(index // num_copy)]  # range [0,1]
    denoised = get_output(in_img=noisy,
                          netName=netName,
                          sigma=sigma,
                          num_copy=num_copy)
    psnr = PSNR(img - denoised)
    # Debug output: value ranges of the clean image and the network output.
    print(np.max(img), np.min(img))
    print(np.max(denoised), np.min(denoised))
    imsave('%s/test_noisy_%d.jpg' % (resultDir, index), noisy)
    imsave('%s/test_img_%d.jpg' % (resultDir, index), img)
    # The network output is not guaranteed to stay in [0, 1]. Clip before
    # saving, as get_denoised_dataset does, so the JPEG is neither rescaled
    # nor rejected by imsave.
    imsave('%s/test_denoised_%d.jpg' % (resultDir, index),
           np.clip(denoised, 0, 1))
    return psnr
Example #2
0
# Data Part
# Location of the STL-10 data; the commented path below is the Gcloud setup.
dataDir = '/home/yuqi/spinner/dataset/stl10'
# dataDir = '../stl10'  # Running on Gcloud
num_train = 2000  # max 5000
num_test = 5000
batch_size_train, batch_size_test = 32, 20
best_accu = 0

img_transform = transforms.Compose([transforms.ToTensor()])

# The train and test splits differ only in the `train` flag, so the shared
# constructor arguments are gathered once (and removed again below so the
# module namespace is unchanged).
_stl10_kwargs = dict(num_train=num_train,
                     num_test=num_test,
                     num_copy=num_copy,
                     dataDir=dataDir,
                     transform=img_transform)

trainset = dataset.noisy_stl10(sigma, train=True, **_stl10_kwargs)
trainloader = DataLoader(trainset,
                         batch_size=batch_size_train,
                         shuffle=True)

testset = dataset.noisy_stl10(sigma, train=False, **_stl10_kwargs)
testloader = DataLoader(testset,
                        batch_size=batch_size_test,
                        shuffle=True)

del _stl10_kwargs

# Training hyper-parameters taken from the command-line arguments.
num_epochs = args.epoch
reg_term = args.reg_term
Example #3
0
def _denoise_loader(net, loader, item_shape, device, on_batch=None):
    """Run ``net`` over every batch of ``loader`` and collect the outputs.

    The loader must iterate in a fixed order (``shuffle=False``) so that row
    ``k`` of the result corresponds to item ``k`` of ``loader.dataset``.

    Args:
        net: denoising network, called on the noisy batch moved to ``device``.
        loader: DataLoader yielding ``(noisy, img, targets)`` batches.
        item_shape: shape of one result row (the HWC image shape).
        device: device the noisy batch is moved to for the forward pass.
        on_batch: optional ``callback(start, noisy, img, outputs)`` invoked
            after each batch. ``start`` is the dataset index of the batch's
            first item and ``outputs`` is the HWC numpy batch.

    Returns:
        float64 numpy array of shape ``(len(loader.dataset),) + item_shape``.
    """
    # Size by what the loader actually yields rather than by the clean-image
    # array: with num_copy > 1 there may be more noisy items than clean ones.
    denoised = np.zeros((len(loader.dataset),) + tuple(item_shape))
    start = 0
    with torch.no_grad():
        for noisy, img, _targets in loader:
            outputs = net(noisy.to(device))  # now N*C*H*W
            outputs = outputs.cpu().detach().numpy().transpose(
                (0, 2, 3, 1))  # now N*H*W*C
            denoised[start:start + len(outputs)] = outputs
            if on_batch is not None:
                on_batch(start, noisy, img, outputs)
            start += len(outputs)
    return denoised


def _save_samples(outDir, netName, sigma, start, noisy, img, outputs,
                  every=100):
    """Save denoised/clean/noisy test JPEGs for every ``every``-th image."""
    # Convert once per batch under new names. The original rebound `noisy`
    # to a single image here, which broke any later hit in the same batch.
    noisy_np = noisy.cpu().detach().numpy().transpose((0, 2, 3, 1))
    img_np = img.cpu().detach().numpy().transpose((0, 2, 3, 1))
    for i, out in enumerate(outputs):
        img_idx = start + i  # running index across the whole test split
        if img_idx % every != 0:
            continue
        imsave(
            '%s/images/%s_sigma%.2f_test_denoised_%d.jpg' %
            (outDir, netName, sigma, img_idx), np.clip(out, 0, 1))
        imsave(
            '%s/images/%s_sigma%.2f_test_img_%d.jpg' %
            (outDir, netName, sigma, img_idx), img_np[i])
        imsave(
            '%s/images/%s_sigma%.2f_test_noisy_%d.jpg' %
            (outDir, netName, sigma, img_idx), noisy_np[i])


def get_denoised_dataset(dataDir, outDir, sigma, netName, num_copy=1):
    """Denoise both STL-10 splits with a trained network and save an ``.npz``.

    Loads ``./checkpoints/ckpt_<netName>_sigma<sigma>_copy<num_copy>.t7``,
    runs the network over the noisy train and test splits, and writes
    ``<dataDir>/denoisedSTL10_<netName>_sigma<sigma>_copy<num_copy>.npz``.
    The archive contains the denoised arrays alongside the clean data, noisy
    data, labels and run settings. Every 100th test image is also saved as
    JPEGs under ``<outDir>/images/``. That directory is assumed to exist;
    nothing here creates it.

    Requires CUDA: the model is moved to the GPU unconditionally.
    """
    checkpoint = torch.load('./checkpoints/ckpt_%s_sigma%.2f_copy%d.t7' %
                            (netName, sigma, num_copy))
    # Drop the first 6 characters of netName (presumably a 'dncnn_' prefix,
    # e.g. 'dncnn_CNN16' -> 'CNN16') to obtain the architecture name.
    net = models.dncnn.deepcnn(netName[6:]).cuda()
    net = torch.nn.DataParallel(net)
    cudnn.benchmark = True
    net.load_state_dict(checkpoint['net'])
    # Inference mode: layers such as BatchNorm/Dropout switch to their
    # evaluation behaviour (a no-op for models without them).
    net.eval()
    device = 'cuda'
    img_transform = transforms.Compose([transforms.ToTensor()])

    batch_size_train = 32
    batch_size_test = 20
    trainset = dataset.noisy_stl10(sigma,
                                   num_copy=num_copy,
                                   dataDir=dataDir,
                                   transform=img_transform,
                                   train=True)
    # shuffle=False keeps outputs aligned with the dataset order.
    trainloader = DataLoader(trainset,
                             batch_size=batch_size_train,
                             shuffle=False)
    testset = dataset.noisy_stl10(sigma,
                                  num_copy=num_copy,
                                  dataDir=dataDir,
                                  transform=img_transform,
                                  train=False)
    testloader = DataLoader(testset, batch_size=batch_size_test, shuffle=False)

    train_data_denoised = _denoise_loader(net, trainloader,
                                          trainset.train_data.shape[1:],
                                          device)

    def save_test_samples(start, noisy, img, outputs):
        _save_samples(outDir, netName, sigma, start, noisy, img, outputs)

    test_data_denoised = _denoise_loader(net, testloader,
                                         testset.test_data.shape[1:],
                                         device,
                                         on_batch=save_test_samples)

    # Note: the archive is written into dataDir, not outDir.
    fileName = '%s/denoisedSTL10_%s_sigma%.2f_copy%d.npz' % (dataDir, netName,
                                                             sigma, num_copy)
    np.savez(fileName,
             train_data_denoised=train_data_denoised,
             train_labels=trainset.train_labels,
             train_data=trainset.train_data,
             train_data_noisy=trainset.train_data_noisy,
             test_data_denoised=test_data_denoised,
             test_labels=testset.test_labels,
             test_data=testset.test_data,
             test_data_noisy=testset.test_data_noisy,
             num_copy=num_copy,
             sigma=sigma,
             netName=netName)