Exemplo n.º 1
0
def init_weights(net, init_type='xavier_uniform', init_bn_type='uniform', gain=1):
    """
    # Kai Zhang, https://github.com/cszn/KAIR
    #
    # Initialize, in place, every Conv*/Linear and BatchNorm2d submodule of
    # ``net`` (``net`` itself included) by visiting them with ``net.apply``.
    #
    # Args:
    #   net: torch.nn.Module to initialize.
    #   init_type: scheme for Conv*/Linear weights, one of
    #       normal; uniform; xavier_normal; xavier_uniform;
    #       kaiming_normal; kaiming_uniform; orthogonal
    #   init_bn_type: scheme for BatchNorm2d affine parameters, one of
    #       uniform; constant
    #   gain:
    #       e.g. 0.2 -- passed as ``gain`` to the xavier/orthogonal
    #       initializers, multiplied into the weights for the other schemes.
    #
    #   Conv*/Linear biases (when present) are always reset to zero.
    #
    # Raises:
    #   NotImplementedError: if an unknown init_type / init_bn_type is met
    #       while visiting a module of the matching kind.
    """
    print('Initialization method [{:s} + {:s}], gain is [{:.2f}]'.format(init_type, init_bn_type, gain))

    def init_fn(m, init_type='xavier_uniform', init_bn_type='uniform', gain=1):
        classname = m.__class__.__name__

        if 'Conv' in classname or 'Linear' in classname:
            weight = getattr(m, 'weight', None)
            # A container/custom block whose class name merely contains
            # "Conv"/"Linear" (e.g. a user-defined ``ConvBlock``) owns no
            # weight; its real layers are visited separately by ``net.apply``.
            if weight is None:
                return
            w = weight.data

            if init_type == 'normal':
                init.normal_(w, 0, 0.1)
                w.clamp_(-1, 1).mul_(gain)

            elif init_type == 'uniform':
                init.uniform_(w, -0.2, 0.2)
                w.mul_(gain)

            elif init_type == 'xavier_normal':
                init.xavier_normal_(w, gain=gain)
                w.clamp_(-1, 1)

            elif init_type == 'xavier_uniform':
                init.xavier_uniform_(w, gain=gain)

            elif init_type == 'kaiming_normal':
                init.kaiming_normal_(w, a=0, mode='fan_in', nonlinearity='relu')
                w.clamp_(-1, 1).mul_(gain)

            elif init_type == 'kaiming_uniform':
                init.kaiming_uniform_(w, a=0, mode='fan_in', nonlinearity='relu')
                w.mul_(gain)

            elif init_type == 'orthogonal':
                init.orthogonal_(w, gain=gain)

            else:
                raise NotImplementedError('Initialization method [{:s}] is not implemented'.format(init_type))

            bias = getattr(m, 'bias', None)
            if bias is not None:
                bias.data.zero_()

        elif 'BatchNorm2d' in classname:

            if init_bn_type == 'uniform':  # preferred
                if m.affine:
                    init.uniform_(m.weight.data, 0.1, 1.0)
                    init.constant_(m.bias.data, 0.0)
            elif init_bn_type == 'constant':
                if m.affine:
                    init.constant_(m.weight.data, 1.0)
                    init.constant_(m.bias.data, 0.0)
            else:
                raise NotImplementedError('Initialization method [{:s}] is not implemented'.format(init_bn_type))

    fn = functools.partial(init_fn, init_type=init_type, init_bn_type=init_bn_type, gain=gain)
    net.apply(fn)
Exemplo n.º 2
0
# Script fragment: load a pretrained DnCNN, corrupt a test image with
# Gaussian noise, and (in code beyond this excerpt) compare PSNR before and
# after adapting the model.
#
# NOTE(review): this fragment is not runnable as shown. `test_im` and
# `train_im` are unfilled placeholders (`= # complete` is a SyntaxError), the
# final `with torch.no_grad():` has no body, and `model_pool`, `model_name`,
# `n_channels`, `DnCNN`, `util`, `os`, `np`, `torch` and `nn` must be defined
# or imported elsewhere.

# Std of the added Gaussian noise, in 0-255 units (it is divided by 255 below).
sigma = 50

# Paths of the test and training images -- placeholders to be filled in.
test_im = # complete
train_im = # complete

# Use the GPU when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Summed (not averaged) squared error; presumably the loss for the adaptation
# step -- not used within this excerpt.
criterion = nn.MSELoss(reduction='sum')
# Presumably adaptation hyperparameters -- not used within this excerpt.
lr = 1e-5
epochs = 50

model_path = os.path.join(model_pool, model_name+'.pth')
# os.path.join with a single argument returns that argument unchanged.
test_path = os.path.join(test_im)
train_path = os.path.join(train_im)

# load model
# Checkpoint keys must match the architecture exactly (strict=True).
# NOTE(review): torch.load is called without map_location; a checkpoint saved
# from GPU tensors needs map_location to load on a CPU-only machine.
model = DnCNN(in_nc=n_channels, out_nc=n_channels, nc=64, nb=17, act_mode='R')
model.load_state_dict(torch.load(model_path), strict=True)
model = model.to(device)
model.eval()

# load test image  
x = util.imread_uint(test_path, n_channels=n_channels)
# Keep the original (squeezed) image, presumably for later comparison.
orig_im = x.squeeze()
# Presumably converts uint8 to float in [0, 1], matching the sigma/255 scale
# of the noise -- TODO confirm against util.uint2single.
x = util.uint2single(x)
np.random.seed(seed=0)  # for reproducibility
# The noisy image is not clipped back to the valid range.
y = x + np.random.normal(0, sigma/255., x.shape) # add gaussian noise
# Presumably adds batch/channel dims to form an (N, C, H, W) tensor -- TODO confirm.
y = util.single2tensor4(y)
y = y.to(device)

# denoise the image to compare PSNR before and after adaptation
with torch.no_grad():
    # NOTE(review): the body of this block is missing from this excerpt.
Exemplo n.º 3
0
def denoising_gray(noise_im, clean_im, LR=1e-2, sigma=5, rho=1, eta=0.5, 
                   total_step=20, prob1_iter=1000, result_root=None, fo=None):
    """Denoise a grayscale image by alternating three sub-problems.

    Each outer step (at most ``total_step`` of them) does the following:
      1. Sub-problem 1 trains an encoder/decoder UNet pair with Adam. It runs
         for ``prob1_iter`` inner iterations, where the count is first
         replaced by ``iteration_decay(prob1_iter)``. The estimate ``i0`` is
         updated after every inner step.
      2. Sub-problem 2 runs a frozen, pretrained DnCNN on ``(i0 + Y) / 255``.
         Its output is rescaled by 255 to give ``i0_til_torch``.
      3. Sub-problem 3 updates ``Y`` by ``eta * (i0 - i0_til_torch)``.

    After the three sub-problems, PSNR and SSIM of ``i0_til_torch`` against
    ``clean_im`` are computed. Intermediate images are saved and a log line
    is printed. The loop stops at the first outer step whose PSNR does not
    exceed the best PSNR so far.

    The structure (penalty ``rho``, accumulated residual ``Y``) resembles an
    ADMM / plug-and-play scheme. NOTE(review): confirm against the method
    this implements.

    Args:
        noise_im: noisy input image (numpy), converted with ``np_to_torch``.
            Presumably on the 0-255 scale, since the DnCNN input is divided
            by 255 and PSNR/SSIM use a data range of 255. TODO confirm.
        clean_im: ground-truth image (numpy), used only for PSNR/SSIM.
        LR: Adam learning rate for both UNets.
        sigma: the prior term ``l2_loss(mean, i0)`` and the ``i0`` update
            are weighted by ``1 / sigma**2``.
        rho: weight of ``(i0_til_torch - Y)`` in the ``i0`` update.
        eta: step size of the ``Y`` update.
        total_step: maximum number of outer iterations.
        prob1_iter: inner-iteration count for sub-problem 1. It is replaced
            by ``iteration_decay(prob1_iter)`` at the start of each outer
            step.
        result_root: string prefix prepended (with ``+``) to every saved
            file name. NOTE(review): the default ``None`` raises
            ``TypeError`` at the first save, so callers must pass a string
            (typically ending in a path separator).
        fo: file object passed to ``print(file=...)``; ``None`` prints to
            stdout.

    Returns:
        Tuple ``(best_psnr, best_ssim)``: the highest PSNR reached, and the
        SSIM from that same outer step.

    NOTE(review): ``device`` is not defined in this function and must come
    from module scope. It has to match the ``.cuda()`` device that the
    networks are moved to.
    """
    input_depth = 1
    latent_dim = 1
    
    # Encoder/decoder pair; UNet args presumably (in_channels, out_channels).
    en_net = UNet(input_depth, latent_dim, need_sigmoid = False).cuda()
    de_net = UNet(latent_dim, input_depth, need_sigmoid = False).cuda()
    
    # Pretrained blind grayscale DnCNN used as a fixed denoiser in sub-problem 2.
    model = DnCNN(1, 1, nc=64, nb=20, act_mode='R')
    model_path = './model_zoo/dncnn_gray_blind.pth'
    model.load_state_dict(torch.load(model_path), strict=True)
    model.eval()
    model = model.cuda()
    # Freeze the denoiser: its parameters are never updated.
    for k, v in model.named_parameters():
        v.requires_grad = False

    en_optimizer = torch.optim.Adam(en_net.parameters(), lr = LR)
    de_optimizer = torch.optim.Adam(de_net.parameters(), lr = LR)
    
    l2_loss = torch.nn.MSELoss().cuda()
    
    # Both the estimate i0 and the denoiser output i0_til start as the noisy
    # image; the residual Y starts at zero.
    i0 = np_to_torch(noise_im).to(device)
    noise_im_torch = np_to_torch(noise_im).to(device)
    Y = torch.zeros_like(noise_im_torch).to(device)
    i0_til_torch = np_to_torch(noise_im).to(device)
    
    best_psnr = 0
    best_ssim = 0
    
    for i in range(total_step):
        
############################### sub-problem 1 #################################

        prob1_iter = iteration_decay(prob1_iter)

        for i_1 in range(prob1_iter):
            mean = en_net(noise_im_torch)

            # Standard-normal noise shaped like `mean` (normal_() overwrites
            # the cloned values); the decoder sees a perturbed latent.
            eps = mean.clone().normal_()
            out = de_net(mean + eps)

            # Data term: decoder output vs. the noisy image. Prior term: ties
            # the latent `mean` to the current estimate i0 (weight 1/sigma**2).
            total_loss =  0.5 * l2_loss(out, noise_im_torch)
            total_loss += 0.5 * 1/(sigma**2) * l2_loss(mean, i0)
            
            en_optimizer.zero_grad()
            de_optimizer.zero_grad()

            total_loss.backward()

            en_optimizer.step()
            de_optimizer.step()
            
            # Weighted average of `mean` (weight 1/sigma**2) and
            # `i0_til_torch - Y` (weight rho).
            with torch.no_grad():
                i0 = ((1/sigma**2)*mean + rho*(i0_til_torch - Y)) / ((1/sigma**2) + rho)
        
        with torch.no_grad():
            
############################### sub-problem 2 #################################
            
            # The denoiser works on the input divided by 255; its output is
            # scaled back by 255.
            i0_til_torch = model((i0+Y)/255) * 255

############################### sub-problem 3 #################################

            Y = Y + eta * (i0 - i0_til_torch)

###############################################################################

            i0_til_np = torch_to_np(i0_til_torch).clip(0, 255)            
            psnr_gt = compare_psnr(clean_im, i0_til_np, 255)
            ssim_gt = compare_ssim(i0_til_np, clean_im, multichannel=False, data_range=255)
            
            denoise_obj_name = 'denoise_obj_{:04d}'.format(i) + '.png'            
            i0_name = 'i0_num_epoch_{:04d}'.format(i) + '.png'
            result_name = 'num_epoch_{:04d}'.format(i) + '.png'
            mean_name = 'Latent_im_num_epoch_{:04d}'.format(i) + '.png'
            out_name = 'res_of_dec_num_epoch_{:04d}'.format(i) + '.png'
            
            # Saved before the early-stop check, so the step that triggers
            # `break` is written too.
            # NOTE(review): `mean`, `out` and `total_loss` come from the last
            # inner iteration; if iteration_decay ever yields 0 on the first
            # outer step they are unbound here.
            save_torch(Y+i0, result_root + denoise_obj_name)
            save_torch(i0, result_root + i0_name)
            save_torch(i0_til_torch, result_root + result_name)
            save_torch(mean, result_root + mean_name)
            save_torch(out, result_root + out_name)

            print('Iteration %02d  Loss %f  PSNR_gt: %f, SSIM_gt: %f' % (i, total_loss.item(), psnr_gt, ssim_gt), file=fo, flush=True)
            
            # Early stop: quit at the first step whose PSNR does not improve
            # on the best so far (ties also stop).
            if best_psnr < psnr_gt:
                best_psnr = psnr_gt
                best_ssim = ssim_gt
            else:
                break
    
    return best_psnr, best_ssim