# Example #1
# 0
def main(args):
    r"""Performs the main FFDNet training loop.

    Builds the training/validation datasets, the FFDNet model, a summed MSE
    loss and an Adam optimizer. Optionally resumes from
    ``<args.log_dir>/ckpt.pth``, then trains until ``args.epochs``. After each
    epoch the model is validated on ``dataset_val`` at noise level
    ``args.val_noiseL``. Scalars and sample images are written to TensorBoard,
    and ``net.pth`` / ``ckpt.pth`` (plus periodic ``ckpt_e<N>.pth``) are
    saved in ``args.log_dir``.

    Args:
        args: parsed command-line namespace. Attributes read here: gray,
            batch_size, log_dir, lr, resume_training, epochs, milestone,
            no_orthog, noiseIntL, save_every, val_noiseL, save_every_epochs.

    Raises:
        Exception: if ``args.resume_training`` is set but
            ``<args.log_dir>/ckpt.pth`` does not exist.
    """
    # Load dataset
    print('> Loading dataset ...')
    dataset_train = Dataset(train=True, gray_mode=args.gray, shuffle=True)
    dataset_val = Dataset(train=False, gray_mode=args.gray, shuffle=False)
    loader_train = DataLoader(dataset=dataset_train, num_workers=6,
                              batch_size=args.batch_size, shuffle=True)
    print("\t# of training samples: %d\n" % int(len(dataset_train)))

    # Init loggers
    if not os.path.exists(args.log_dir):
        os.makedirs(args.log_dir)
    writer = SummaryWriter(args.log_dir)
    logger = init_logger(args)

    # Create model: 1 input channel for grayscale, 3 for RGB
    in_ch = 1 if args.gray else 3
    net = FFDNet(num_input_channels=in_ch)
    # Initialize model with He init
    net.apply(weights_init_kaiming)
    # Summed (not averaged) squared error; it is divided by 2*batch below.
    # (`size_average=False` is the pre-0.4.1 spelling of reduction='sum'.)
    criterion = nn.MSELoss(size_average=False)

    # Move to GPU
    device_ids = [0]
    model = nn.DataParallel(net, device_ids=device_ids).cuda()
    criterion.cuda()

    # Optimizer
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    # Resume training or start anew
    if args.resume_training:
        resumef = os.path.join(args.log_dir, 'ckpt.pth')
        if os.path.isfile(resumef):
            checkpoint = torch.load(resumef)
            print("> Resuming previous training")
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            # Keep the schedule given on the command line, but restore every
            # other argument from the checkpoint.
            new_epoch = args.epochs
            new_milestone = args.milestone
            current_lr = args.lr
            args = checkpoint['args']
            training_params = checkpoint['training_params']
            start_epoch = training_params['start_epoch']
            args.epochs = new_epoch
            args.milestone = new_milestone
            args.lr = current_lr
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(resumef, start_epoch))
            print("=> loaded parameters :")
            print("==> checkpoint['optimizer']['param_groups']")
            print("\t{}".format(checkpoint['optimizer']['param_groups']))
            print("==> checkpoint['training_params']")
            for k in checkpoint['training_params']:
                print("\t{}, {}".format(k, checkpoint['training_params'][k]))
            argpri = vars(checkpoint['args'])
            print("==> checkpoint['args']")
            for k in argpri:
                print("\t{}, {}".format(k, argpri[k]))

            args.resume_training = False
        else:
            raise Exception("Couldn't resume training with checkpoint {}".
                            format(resumef))
    else:
        start_epoch = 0
        training_params = {}
        training_params['step'] = 0
        training_params['current_lr'] = 0
        training_params['no_orthog'] = args.no_orthog

    # Training
    for epoch in range(start_epoch, args.epochs):
        # Learning rate value scheduling according to args.milestone
        if epoch > args.milestone[1]:
            current_lr = args.lr / 1000.
            training_params['no_orthog'] = True
        elif epoch > args.milestone[0]:
            current_lr = args.lr / 10.
        else:
            current_lr = args.lr

        # set learning rate in optimizer
        for param_group in optimizer.param_groups:
            param_group["lr"] = current_lr
        print('learning rate %f' % current_lr)

        # train
        for i, data in enumerate(loader_train, 0):
            # Pre-training step
            model.train()
            model.zero_grad()
            optimizer.zero_grad()

            # inputs: per-sample Gaussian noise with std drawn from noiseIntL
            img_train = data
            noise = torch.zeros(img_train.size())
            stdn = np.random.uniform(args.noiseIntL[0], args.noiseIntL[1],
                                     size=noise.size()[0])
            for nx in range(noise.size()[0]):
                sizen = noise[0, :, :, :].size()
                noise[nx, :, :, :] = torch.FloatTensor(sizen).\
                    normal_(mean=0, std=stdn[nx])
            imgn_train = img_train + noise
            # Move inputs to GPU. `Variable(..., volatile=True)` is a no-op
            # since PyTorch 0.4, so plain tensors are used.
            img_train = img_train.cuda()
            imgn_train = imgn_train.cuda()
            noise = noise.cuda()
            stdn_var = torch.cuda.FloatTensor(stdn)

            # Evaluate model and optimize it (the network predicts the noise)
            out_train = model(imgn_train, stdn_var)
            loss = criterion(out_train, noise) / (imgn_train.size()[0] * 2)
            loss.backward()
            optimizer.step()

            # Results: evaluation-only forward pass, no autograd graph needed
            model.eval()
            with torch.no_grad():
                out_train = torch.clamp(
                    imgn_train - model(imgn_train, stdn_var), 0., 1.)
            psnr_train = batch_psnr(out_train, img_train, 1.)

            if training_params['step'] % args.save_every == 0:
                # Apply regularization by orthogonalizing filters
                if not training_params['no_orthog']:
                    model.apply(svd_orthogonalization)

                # loss.item() replaces loss.data[0], which raises an
                # IndexError on 0-dim tensors since PyTorch 0.4.
                loss_value = loss.item()
                # Log the scalar values
                writer.add_scalar('loss', loss_value,
                                  training_params['step'])
                writer.add_scalar('PSNR on training data', psnr_train,
                                  training_params['step'])
                print("[epoch %d][%d/%d] loss: %.4f PSNR_train: %.4f" %
                      (epoch + 1, i + 1, len(loader_train), loss_value,
                       psnr_train))
            training_params['step'] += 1
        # The end of each epoch
        model.eval()

        # Validation (inference only)
        psnr_val = 0
        with torch.no_grad():
            for valimg in dataset_val:
                img_val = torch.unsqueeze(valimg, 0)
                noise = torch.FloatTensor(img_val.size()).\
                    normal_(mean=0, std=args.val_noiseL)
                imgn_val = img_val + noise
                img_val, imgn_val = img_val.cuda(), imgn_val.cuda()
                sigma_noise = torch.cuda.FloatTensor([args.val_noiseL])
                out_val = torch.clamp(
                    imgn_val - model(imgn_val, sigma_noise), 0., 1.)
                psnr_val += batch_psnr(out_val, img_val, 1.)
        psnr_val /= len(dataset_val)
        print("\n[epoch %d] PSNR_val: %.4f" % (epoch + 1, psnr_val))
        writer.add_scalar('PSNR on validation data', psnr_val, epoch)
        writer.add_scalar('Learning rate', current_lr, epoch)

        # Log val images (from the last validation sample)
        try:
            # Validation tensors hold a single image (batch of 1 after the
            # unsqueeze above); indexing past it would raise and skip the
            # remaining image logging.
            n_val_log = min(2, img_val.size(0))
            if epoch == 0:
                # Log graph of the model
                writer.add_graph(
                    model,
                    (imgn_val, sigma_noise),
                )
                # Log validation images
                for idx in range(n_val_log):
                    imclean = utils.make_grid(img_val.data[idx].clamp(0., 1.),
                                              nrow=2, normalize=False,
                                              scale_each=False)
                    imnsy = utils.make_grid(imgn_val.data[idx].clamp(0., 1.),
                                            nrow=2, normalize=False,
                                            scale_each=False)
                    writer.add_image('Clean validation image {}'.format(idx),
                                     imclean, epoch)
                    writer.add_image('Noisy validation image {}'.format(idx),
                                     imnsy, epoch)
            for idx in range(n_val_log):
                imrecons = utils.make_grid(out_val.data[idx].clamp(0., 1.),
                                           nrow=2, normalize=False,
                                           scale_each=False)
                writer.add_image(
                    'Reconstructed validation image {}'.format(idx),
                    imrecons, epoch)
            # Log training images (last training batch)
            imclean = utils.make_grid(img_train.data, nrow=8, normalize=True,
                                      scale_each=True)
            writer.add_image('Training patches', imclean, epoch)

        except Exception as e:
            logger.error("Couldn't log results: {}".format(e))

        # save model and checkpoint
        training_params['start_epoch'] = epoch + 1
        torch.save(model.state_dict(), os.path.join(args.log_dir, 'net.pth'))
        save_dict = {
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'training_params': training_params,
            'args': args,
        }
        torch.save(save_dict, os.path.join(args.log_dir, 'ckpt.pth'))
        if epoch % args.save_every_epochs == 0:
            torch.save(save_dict, os.path.join(
                args.log_dir, 'ckpt_e{}.pth'.format(epoch + 1)))
        del save_dict

    # Flush pending TensorBoard events
    writer.close()
# Example #2
# 0
def test_ffdnet(**args):
    r"""Denoises an input image with FFDNet.

    Reads ``args['input']`` (RGB or grayscale) and optionally adds Gaussian
    noise. It then denoises the image with the matching pretrained FFDNet and
    prints the runtime, the PSNR (only when noise was added) and, for
    grayscale inputs, the SSIM of the noisy and denoised images against the
    original.

    Keyword Args:
        input: path to the input image.
        cuda: run on the GPU when truthy.
        add_noise: add synthetic Gaussian noise before denoising.
        noise_sigma: noise std, presumably on the [0, 1] intensity scale
            (it is multiplied by 255 for the output file names); also fed to
            the network as its noise-level input.
        dont_save_results: skip writing output images when truthy.

    Raises:
        Exception: if the input image cannot be opened.
    """
    # Init logger
    logger = init_logger_ipol()

    # Check if input exists and if it is RGB
    try:
        rgb_den = is_rgb(args['input'])
    except Exception:
        raise Exception('Could not open the input image')

    # Unpadded grayscale original, kept for the SSIM comparison at the end;
    # stays None for RGB inputs.
    imorig_copy = None

    # Open image as a 1xCxHxW array
    if rgb_den:
        in_ch = 3
        model_fn = 'models/net_rgb.pth'
        imorig = cv2.imread(args['input'])
        # from HxWxC (BGR) to CxHxW (RGB)
        imorig = (cv2.cvtColor(imorig, cv2.COLOR_BGR2RGB)).transpose(2, 0, 1)
    else:
        # from HxW to CxHxW grayscale image (C=1)
        in_ch = 1
        model_fn = 'logs/net.pth'
        imorig = cv2.imread(args['input'], cv2.IMREAD_GRAYSCALE)
        imorig_copy = imorig.copy()
        imorig = np.expand_dims(imorig, 0)
    # Add the batch dimension
    imorig = np.expand_dims(imorig, 0)

    # Handle odd sizes: replicate the last row/column so H and W are even
    expanded_h = False
    expanded_w = False
    sh_im = imorig.shape
    if sh_im[2] % 2 == 1:
        expanded_h = True
        imorig = np.concatenate(
            (imorig, imorig[:, :, -1, :][:, :, np.newaxis, :]), axis=2)

    if sh_im[3] % 2 == 1:
        expanded_w = True
        imorig = np.concatenate(
            (imorig, imorig[:, :, :, -1][:, :, :, np.newaxis]), axis=3)

    imorig = normalize(imorig)
    imorig = torch.Tensor(imorig)

    # Absolute path to model file
    model_fn = os.path.join(os.path.abspath(os.path.dirname(__file__)),
                            model_fn)

    # Create model
    print('Loading model ...\n')
    net = FFDNet(num_input_channels=in_ch)

    # Load saved weights
    if args['cuda']:
        state_dict = torch.load(model_fn)
        device_ids = [0]
        model = nn.DataParallel(net, device_ids=device_ids).cuda()
    else:
        state_dict = torch.load(model_fn, map_location='cpu')
        # CPU mode: remove the DataParallel wrapper
        state_dict = remove_dataparallel_wrapper(state_dict)
        model = net
    model.load_state_dict(state_dict)

    # Sets the model in evaluation mode (e.g. it removes BN)
    model.eval()

    # Sets data type according to CPU or GPU modes
    if args['cuda']:
        dtype = torch.cuda.FloatTensor
    else:
        dtype = torch.FloatTensor

    # Add noise
    if args['add_noise']:
        noise = torch.FloatTensor(imorig.size()).\
            normal_(mean=0, std=args['noise_sigma'])
        imnoisy = imorig + noise
    else:
        imnoisy = imorig.clone()

    # Test mode: the entire forward pass runs without an autograd graph
    with torch.no_grad():
        imorig = imorig.type(dtype)
        imnoisy = imnoisy.type(dtype)
        nsigma = torch.FloatTensor([args['noise_sigma']]).type(dtype)

        # Measure runtime of the denoising step
        start_t = time.time()

        # Estimate noise and subtract it from the input image
        im_noise_estim = model(imnoisy, nsigma)
        outim = torch.clamp(imnoisy - im_noise_estim, 0., 1.)
        if args['cuda']:
            # CUDA kernels run asynchronously; wait before stopping the clock
            torch.cuda.synchronize()
        stop_t = time.time()

    # Undo the odd-size padding
    if expanded_h:
        imorig = imorig[:, :, :-1, :]
        outim = outim[:, :, :-1, :]
        imnoisy = imnoisy[:, :, :-1, :]

    if expanded_w:
        imorig = imorig[:, :, :, :-1]
        outim = outim[:, :, :, :-1]
        imnoisy = imnoisy[:, :, :, :-1]

    # Compute PSNR and log it
    if rgb_den:
        print("### RGB denoising ###")
    else:
        print("### Grayscale denoising ###")
    if args['add_noise']:
        psnr = batch_psnr(outim, imorig, 1.)
        psnr_noisy = batch_psnr(imnoisy, imorig, 1.)

        print("----------PSNR noisy    {0:0.2f}dB".format(psnr_noisy))
        print("----------PSNR denoised {0:0.2f}dB".format(psnr))
    else:
        logger.info("\tNo noise was added, cannot compute PSNR")
    print("----------Runtime     {0:0.4f}s".format(stop_t - start_t))

    # Compute difference images (scaled x2 and centred at 0.5)
    diffout = 2 * (outim - imorig) + .5
    diffnoise = 2 * (imnoisy - imorig) + .5

    # Convert outputs once; they are needed both for saving and for SSIM
    noisyimg = variable_to_cv2_image(imnoisy)
    outimg = variable_to_cv2_image(outim)

    # Save images
    if not args['dont_save_results']:
        cv2.imwrite(
            "bfffd/noisy-" + str(int(args['noise_sigma'] * 255)) + '-' +
            args['input'], noisyimg)
        cv2.imwrite(
            "bfffd/ffdnet-" + str(int(args['noise_sigma'] * 255)) + '-' +
            args['input'], outimg)

        if args['add_noise']:
            cv2.imwrite("noisy_diff.png", variable_to_cv2_image(diffnoise))
            cv2.imwrite("ffdnet_diff.png", variable_to_cv2_image(diffout))

    # SSIM is computed against the single-channel original, which only
    # exists for grayscale inputs.
    if imorig_copy is not None:
        (score, _) = compare_ssim(noisyimg, imorig_copy, full=True)
        (score2, _) = compare_ssim(outimg, imorig_copy, full=True)
        print("----------Noisy ssim:    {0:0.4f}".format(score))
        print("----------Denoisy ssim: {0:0.4f}".format(score2))
    else:
        logger.info("\tSSIM is only computed for grayscale inputs")
# Example #3
# 0
# Inference setup for the grayscale FFDNet denoiser.
# Flip `cuda` to True to run on the first GPU instead of the CPU.
cuda = False
# Only the grayscale model is provided here; the RGB model is even better.
rgb_den = False
B = 2048.  # previously tried value: 256.
in_ch = 1  # grayscale input -> a single channel

# Weights trained on the Waterloo Exploration Database, sigma in [0, 20],
# 50 epochs (see Fig. 2 of the paper); training took about 12 hours on an
# Nvidia RTX 2080Ti (11 GB). The path is resolved relative to this file.
model_fn = os.path.join(os.path.abspath(os.path.dirname(__file__)),
                        'mat/net-gray-v2.pth')

print('Loading model ...\n')
net = FFDNet(num_input_channels=in_ch)

if not cuda:
    # CPU: load onto the CPU and strip the DataParallel wrapper so the keys
    # match the bare network.
    state_dict = remove_dataparallel_wrapper(
        torch.load(model_fn, map_location='cpu'))
    model = net
    dtype = torch.FloatTensor
else:
    # GPU: keep the wrapped keys and wrap the network in DataParallel.
    state_dict = torch.load(model_fn)
    device_ids = [0]
    model = nn.DataParallel(net, device_ids=device_ids).cuda()
    dtype = torch.cuda.FloatTensor
model.load_state_dict(state_dict)
# Example #4
# 0
def test_ffdnet(**args):
    r"""Runs FFDNet on an input image and saves the result as a .mat file.

    No synthetic noise is added. The network output ``im_noise_estim`` is
    subtracted from the (normalized) input, and the result is cropped back to
    the input size. It is saved under key ``'Noisex'`` in
    ``./output_noise/[<save_path>/]<name>.mat``.

    NOTE(review): ``input - estimated noise`` is the denoised image, yet it is
    stored as 'Noisex'; confirm which quantity is intended.

    Keyword Args:
        input: path to the input image.
        cuda: run on the GPU when truthy.
        noise_sigma: noise level fed to the network.
        save_path: optional sub-directory of ``./output_noise``.

    Raises:
        Exception: if the input image cannot be opened.
    """
    # Init logger
    logger = init_logger_ipol()

    # Check if input exists and if it is RGB
    try:
        rgb_den = is_rgb(args['input'])
    except Exception:
        raise Exception('Could not open the input image')

    # Start the timer here: the reported runtime includes image loading and
    # model setup, not only the forward pass.
    start_t = time.time()

    # Open image as a 1xCxHxW float32 array
    if rgb_den:
        in_ch = 3
        model_fn = 'net_rgb.pth'
        imorig = Image.open(args['input'])
        imorig = np.array(imorig, dtype=np.float32).transpose(2, 0, 1)
    else:
        # Grayscale: HxW -> CxHxW with C=1. The image must actually be read
        # here; without this load `imorig` would be unbound.
        in_ch = 1
        model_fn = 'models/net_gray.pth'
        imorig = np.array(Image.open(args['input']).convert('L'),
                          dtype=np.float32)
        imorig = np.expand_dims(imorig, 0)
    # Add the batch dimension
    imorig = np.expand_dims(imorig, 0)

    # Handle odd sizes: replicate the last row/column so H and W are even
    expanded_h = False
    expanded_w = False
    sh_im = imorig.shape
    if sh_im[2] % 2 == 1:
        expanded_h = True
        imorig = np.concatenate(
            (imorig, imorig[:, :, -1, :][:, :, np.newaxis, :]), axis=2)

    if sh_im[3] % 2 == 1:
        expanded_w = True
        imorig = np.concatenate(
            (imorig, imorig[:, :, :, -1][:, :, :, np.newaxis]), axis=3)

    imorig = normalize(imorig)
    imorig = torch.Tensor(imorig)
    # Absolute path to model file
    model_fn = os.path.join(os.path.abspath(os.path.dirname(__file__)),
                            model_fn)

    # Create model
    print('Loading model ...\n')
    net = FFDNet(num_input_channels=in_ch)

    # Load saved weights
    if args['cuda']:
        state_dict = torch.load(model_fn)
    else:
        state_dict = torch.load(model_fn, map_location='cpu')
    # State dicts saved from an nn.DataParallel model prefix every key with
    # 'module.'; strip it (only when present) so the weights load into the
    # bare network on either device.
    if any(k.startswith('module.') for k in state_dict):
        state_dict = remove_dataparallel_wrapper(state_dict)
    # The network must live on the same device as its inputs.
    model = net.cuda() if args['cuda'] else net
    model.load_state_dict(state_dict)

    # Sets the model in evaluation mode (e.g. it removes BN)
    model.eval()

    # Sets data type according to CPU or GPU modes
    if args['cuda']:
        dtype = torch.cuda.FloatTensor
    else:
        dtype = torch.FloatTensor

    # Test mode: the forward pass runs without an autograd graph
    with torch.no_grad():
        imorig = imorig.type(dtype)
        nsigma = torch.FloatTensor([args['noise_sigma']]).type(dtype)

        # Estimate noise and subtract it from the input image
        im_noise_estim = model(imorig, nsigma)
        result = imorig - im_noise_estim
        if args['cuda']:
            # CUDA kernels run asynchronously; wait before stopping the clock
            torch.cuda.synchronize()
    stop_t = time.time()

    # Undo the odd-size padding so the saved array matches the input size
    if expanded_h:
        result = result[:, :, :-1, :]
    if expanded_w:
        result = result[:, :, :, :-1]

    # log time
    if rgb_den:
        print("### RGB denoising ###")
    else:
        print("### Grayscale denoising ###")
    print("\tRuntime {0:0.4f}s".format(stop_t - start_t))

    # Save result as HxWxC
    noise = variable_to_numpy(result).transpose(1, 2, 0)
    filename = args['input'].split('/')[-1].split('.')[0]
    if args['save_path']:
        out_dir = './output_noise/' + args['save_path'] + '/'
    else:
        out_dir = './output_noise/'
    if not os.path.isdir(out_dir):
        os.makedirs(out_dir)
    sio.savemat(out_dir + filename + '.mat', {'Noisex': noise})
# Example #5
# 0
# Restrict CUDA to the first GPU in PCI bus order. These variables only take
# effect if set before CUDA is initialised, so keep them first.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Set up the pretrained RGB FFDNet denoiser on the GPU, in evaluation mode.

# Init logger
logger = init_logger_ipol()

# Absolute path to the model file, resolved relative to this script. (It is
# already absolute, so no further path joining is needed.)
model_fn = os.path.join(os.path.abspath(os.path.dirname(__file__)),
                        'models/net_rgb.pth')

# Create model
print('Loading model ...\n')
net = FFDNet(num_input_channels=3, test_mode=True)

# Load saved weights into a DataParallel-wrapped network on GPU 0
print(model_fn)
state_dict = torch.load(model_fn)
device_ids = [0]
model = nn.DataParallel(net, device_ids=device_ids).cuda()
model.load_state_dict(state_dict)

# Sets the model in evaluation mode (e.g. it removes BN)
model.eval()

# Sets data type according to CPU or GPU modes
dtype = torch.cuda.FloatTensor