Example #1
0
def validate_and_log(model_temp, dataset_val, valnoisestd, temp_psz, writer, \
      epoch, lr, logger, trainimg):
    """Run a validation pass after an epoch and log results to TensorBoard.

    Args:
        model_temp: temporal denoising model forwarded to denoise_seq_fastdvdnet.
        dataset_val: iterable of clean validation sequences (torch tensors).
        valnoisestd: std of the Gaussian noise added to validation sequences.
        temp_psz: temporal patch size forwarded to the denoiser.
        writer: TensorBoard SummaryWriter used for scalar/image logging.
        epoch: zero-based index of the epoch that just finished.
        lr: current learning rate, logged as a scalar.
        logger: logger used to report image-logging failures.
        trainimg: batch of training patches, logged as a grid at epoch 0.
    """
    t1 = time.time()
    psnr_val = 0
    with torch.no_grad():
        for seq_val in dataset_val:
            # Synthesize a noisy version of the clean validation sequence
            noise = torch.FloatTensor(seq_val.size()).normal_(mean=0,
                                                              std=valnoisestd)
            seqn_val = seq_val + noise
            seqn_val = seqn_val.cuda()
            sigma_noise = torch.cuda.FloatTensor([valnoisestd])
            out_val = denoise_seq_fastdvdnet(seq=seqn_val, \
                    noise_std=sigma_noise, \
                    temp_psz=temp_psz,\
                    model_temporal=model_temp)
            # NOTE(review): squeeze_() mutates seq_val in place; if dataset_val
            # yields views of stored tensors this alters the dataset -- confirm.
            psnr_val += batch_psnr(out_val.cpu(), seq_val.squeeze_(), 1.)
        # Average PSNR over all validation sequences
        psnr_val /= len(dataset_val)
        t2 = time.time()
        print("\n[epoch %d] PSNR_val: %.4f, on %.2f sec" %
              (epoch + 1, psnr_val, (t2 - t1)))
        writer.add_scalar('PSNR on validation data', psnr_val, epoch)
        writer.add_scalar('Learning rate', lr, epoch)

    # Log val images. seq_val/seqn_val/out_val hold the LAST validation
    # sequence from the loop above; idx selects one frame of it.
    try:
        idx = 0
        if epoch == 0:

            # Log training images
            _, _, Ht, Wt = trainimg.size()
            img = tutils.make_grid(trainimg.view(-1, 3, Ht, Wt), \
                    nrow=8, normalize=True, scale_each=True)
            writer.add_image('Training patches', img, epoch)

            # Log validation images (clean and noisy are static; logged once)
            img = tutils.make_grid(seq_val.data[idx].clamp(0., 1.),\
                  nrow=2, normalize=False, scale_each=False)
            imgn = tutils.make_grid(seqn_val.data[idx].clamp(0., 1.),\
                  nrow=2, normalize=False, scale_each=False)
            writer.add_image('Clean validation image {}'.format(idx), img,
                             epoch)
            writer.add_image('Noisy validation image {}'.format(idx), imgn,
                             epoch)

        # Log validation results (the reconstruction changes every epoch)
        irecon = tutils.make_grid(out_val.data[idx].clamp(0., 1.),\
              nrow=2, normalize=False, scale_each=False)
        writer.add_image('Reconstructed validation image {}'.format(idx),
                         irecon, epoch)

    except Exception as e:
        # Image logging is best-effort: failures are reported but not fatal
        logger.error(
            "validate_and_log_temporal(): Couldn't log results, {}".format(e))
def log_train_psnr(result, imsource, loss, writer, epoch, idx, num_minibatches, training_params):
    """Log the training loss to TensorBoard and print loss/PSNR for one batch.

    Args:
        result: denoised batch produced by the model.
        imsource: corresponding clean batch.
        loss: scalar loss tensor for this minibatch.
        writer: TensorBoard SummaryWriter.
        epoch: zero-based epoch index (printed as epoch+1).
        idx: zero-based minibatch index (printed as idx+1).
        num_minibatches: total minibatches per epoch, for the progress display.
        training_params: dict holding the global 'step' counter.
    """
    # PSNR of the whole batch, computed on the clamped output
    clamped = torch.clamp(result, 0., 1.)
    psnr_train = batch_psnr(clamped, imsource, 1.)
    loss_value = loss.item()

    # Log the scalar values (per-step PSNR scalar logging is left disabled)
    writer.add_scalar('loss', loss_value, training_params['step'])
    # writer.add_scalar('PSNR on training data', psnr_train,
    #                   training_params['step'])
    print("[epoch {}][{}/{}] loss: {:1.4f} PSNR_train: {:1.4f}".format(
        epoch+1, idx+1, num_minibatches, loss_value, psnr_train))
Example #3
0
def main(args):
    r"""Performs the main FFDNet training loop.

    Loads the training/validation datasets, builds an FFDNet model, optionally
    resumes from a checkpoint, then alternates training epochs with validation
    passes, logging scalars/images to TensorBoard and saving checkpoints.

    Args:
        args: parsed command-line options. Fields used: gray, batch_size,
            log_dir, lr, milestone, epochs, resume_training, no_orthog,
            noiseIntL, val_noiseL, save_every, save_every_epochs.
    """
    # Load dataset
    print('> Loading dataset ...')
    dataset_train = Dataset(train=True, gray_mode=args.gray, shuffle=True)
    dataset_val = Dataset(train=False, gray_mode=args.gray, shuffle=False)
    loader_train = DataLoader(dataset=dataset_train, num_workers=6, \
             batch_size=args.batch_size, shuffle=True)
    print("\t# of training samples: %d\n" % int(len(dataset_train)))

    # Init loggers
    if not os.path.exists(args.log_dir):
        os.makedirs(args.log_dir)
    writer = SummaryWriter(args.log_dir)
    logger = init_logger(args)

    # Create model: 3 input channels for RGB, 1 for grayscale
    if not args.gray:
        in_ch = 3
    else:
        in_ch = 1
    net = FFDNet(num_input_channels=in_ch)
    # Initialize model with He init
    net.apply(weights_init_kaiming)
    # Define loss. reduction='sum' replaces the deprecated size_average=False
    # (identical behavior: summed, un-averaged squared error).
    criterion = nn.MSELoss(reduction='sum')

    # Move to GPU
    device_ids = [0]
    model = nn.DataParallel(net, device_ids=device_ids).cuda()
    criterion.cuda()

    # Optimizer
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    # Resume training or start anew
    if args.resume_training:
        resumef = os.path.join(args.log_dir, 'ckpt.pth')
        if os.path.isfile(resumef):
            checkpoint = torch.load(resumef)
            print("> Resuming previous training")
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            # Preserve command-line values: the checkpoint's stored args would
            # otherwise overwrite them when restored below.
            new_epoch = args.epochs
            new_milestone = args.milestone
            current_lr = args.lr
            args = checkpoint['args']
            training_params = checkpoint['training_params']
            start_epoch = training_params['start_epoch']
            args.epochs = new_epoch
            args.milestone = new_milestone
            args.lr = current_lr
            print("=> loaded checkpoint '{}' (epoch {})"\
               .format(resumef, start_epoch))
            print("=> loaded parameters :")
            print("==> checkpoint['optimizer']['param_groups']")
            print("\t{}".format(checkpoint['optimizer']['param_groups']))
            print("==> checkpoint['training_params']")
            for k in checkpoint['training_params']:
                print("\t{}, {}".format(k, checkpoint['training_params'][k]))
            argpri = vars(checkpoint['args'])
            print("==> checkpoint['args']")
            for k in argpri:
                print("\t{}, {}".format(k, argpri[k]))

            args.resume_training = False
        else:
            raise Exception("Couldn't resume training with checkpoint {}".\
                format(resumef))
    else:
        start_epoch = 0
        training_params = {}
        training_params['step'] = 0
        training_params['current_lr'] = 0
        training_params['no_orthog'] = args.no_orthog

    # Training
    for epoch in range(start_epoch, args.epochs):
        # Learning rate value scheduling according to args.milestone:
        # lr, lr/10 and lr/1000 as milestones are passed; orthogonalization
        # is switched off for the final stage.
        if epoch > args.milestone[1]:
            current_lr = args.lr / 1000.
            training_params['no_orthog'] = True
        elif epoch > args.milestone[0]:
            current_lr = args.lr / 10.
        else:
            current_lr = args.lr

        # set learning rate in optimizer
        for param_group in optimizer.param_groups:
            param_group["lr"] = current_lr
        print('learning rate %f' % current_lr)

        # train
        for i, data in enumerate(loader_train, 0):
            # Pre-training step
            model.train()
            model.zero_grad()
            optimizer.zero_grad()

            # inputs: noise and noisy image. Each image in the batch gets its
            # own noise std drawn uniformly from args.noiseIntL.
            img_train = data
            noise = torch.zeros(img_train.size())
            stdn = np.random.uniform(args.noiseIntL[0], args.noiseIntL[1], \
                size=noise.size()[0])
            for nx in range(noise.size()[0]):
                sizen = noise[0, :, :, :].size()
                noise[nx, :, :, :] = torch.FloatTensor(sizen).\
                     normal_(mean=0, std=stdn[nx])
            imgn_train = img_train + noise
            # Move inputs to GPU. The deprecated Variable(..., volatile=True)
            # wrappers were removed: 'volatile' disabled gradient tracking
            # (breaking loss.backward() on pre-0.4 PyTorch) and is a no-op
            # on >= 0.4, where plain tensors suffice.
            img_train = img_train.cuda()
            imgn_train = imgn_train.cuda()
            noise = noise.cuda()
            stdn_var = torch.cuda.FloatTensor(stdn)

            # Evaluate model and optimize it. FFDNet predicts the noise,
            # hence the loss compares the output against the noise tensor.
            out_train = model(imgn_train, stdn_var)
            loss = criterion(out_train, noise) / (imgn_train.size()[0] * 2)
            loss.backward()
            optimizer.step()

            # Results: re-run inference without building an autograd graph
            model.eval()
            with torch.no_grad():
                out_train = torch.clamp(
                    imgn_train - model(imgn_train, stdn_var), 0., 1.)
                psnr_train = batch_psnr(out_train, img_train, 1.)

            if training_params['step'] % args.save_every == 0:
                # Apply regularization by orthogonalizing filters
                if not training_params['no_orthog']:
                    model.apply(svd_orthogonalization)

                # Log the scalar values. loss.item() replaces the deprecated
                # loss.data[0], which raises IndexError on 0-dim tensors in
                # PyTorch >= 0.4 (already noted in the original source).
                writer.add_scalar('loss', loss.item(),
                                  training_params['step'])
                writer.add_scalar('PSNR on training data', psnr_train, \
                   training_params['step'])
                print("[epoch %d][%d/%d] loss: %.4f PSNR_train: %.4f" %\
                 (epoch+1, i+1, len(loader_train), loss.item(), psnr_train))
            training_params['step'] += 1
        # The end of each epoch
        model.eval()

        # Validation (inference only, so wrap in no_grad)
        psnr_val = 0
        with torch.no_grad():
            for valimg in dataset_val:
                img_val = torch.unsqueeze(valimg, 0)
                noise = torch.FloatTensor(img_val.size()).\
                  normal_(mean=0, std=args.val_noiseL)
                imgn_val = img_val + noise
                img_val, imgn_val = img_val.cuda(), imgn_val.cuda()
                sigma_noise = torch.cuda.FloatTensor([args.val_noiseL])
                out_val = torch.clamp(
                    imgn_val - model(imgn_val, sigma_noise), 0., 1.)
                psnr_val += batch_psnr(out_val, img_val, 1.)
        psnr_val /= len(dataset_val)
        print("\n[epoch %d] PSNR_val: %.4f" % (epoch + 1, psnr_val))
        writer.add_scalar('PSNR on validation data', psnr_val, epoch)
        writer.add_scalar('Learning rate', current_lr, epoch)

        # Log val images
        try:
            if epoch == 0:
                # Log graph of the model
                writer.add_graph(
                    model,
                    (imgn_val, sigma_noise),
                )
                # Log validation images (clean/noisy are static; logged once)
                for idx in range(2):
                    imclean = utils.make_grid(img_val.data[idx].clamp(0., 1.), \
                          nrow=2, normalize=False, scale_each=False)
                    imnsy = utils.make_grid(imgn_val.data[idx].clamp(0., 1.), \
                          nrow=2, normalize=False, scale_each=False)
                    writer.add_image('Clean validation image {}'.format(idx),
                                     imclean, epoch)
                    writer.add_image('Noisy validation image {}'.format(idx),
                                     imnsy, epoch)
            for idx in range(2):
                imrecons = utils.make_grid(out_val.data[idx].clamp(0., 1.), \
                      nrow=2, normalize=False, scale_each=False)
                writer.add_image('Reconstructed validation image {}'.format(idx), \
                    imrecons, epoch)
            # Log training images
            imclean = utils.make_grid(img_train.data, nrow=8, normalize=True, \
                scale_each=True)
            writer.add_image('Training patches', imclean, epoch)

        except Exception as e:
            # Image logging is best-effort; failures must not kill training
            logger.error("Couldn't log results: {}".format(e))

        # save model and checkpoint
        training_params['start_epoch'] = epoch + 1
        torch.save(model.state_dict(), os.path.join(args.log_dir, 'net.pth'))
        save_dict = { \
         'state_dict': model.state_dict(), \
         'optimizer' : optimizer.state_dict(), \
         'training_params': training_params, \
         'args': args\
         }
        torch.save(save_dict, os.path.join(args.log_dir, 'ckpt.pth'))
        if epoch % args.save_every_epochs == 0:
            torch.save(save_dict, os.path.join(args.log_dir, \
                    'ckpt_e{}.pth'.format(epoch+1)))
        del save_dict
Example #4
0
def test_dvdnet(**args):
	"""Denoises all sequences present in a given folder. Sequences must be stored as numbered
	image sequences. The different sequences must be stored in subfolders under the "test_path" folder.

	Inputs:
		args (dict) fields:
			"model_spatial_file": path to model of the pretrained spatial denoiser
			"model_temp_file": path to model of the pretrained temporal denoiser
			"test_path": path to sequence to denoise
			"suffix": suffix to add to output name
			"max_num_fr_per_seq": max number of frames to load per sequence
			"noise_sigma": noise level used on test set
			"dont_save_results": if True, don't save output images
			"save_noisy": if True, also save the noisy input frames
			"cuda": if True, run models on GPU; otherwise on CPU
			"save_path": where to save outputs as png
	"""
	start_time = time.time()

	# If save_path does not exist, create it
	if not os.path.exists(args['save_path']):
		os.makedirs(args['save_path'])
	logger = init_logger_test(args['save_path'])

	# Sets data type according to CPU or GPU modes
	if args['cuda']:
		device = torch.device('cuda')
	else:
		device = torch.device('cpu')

	# Create models
	model_spa = DVDnet_spatial()
	model_temp = DVDnet_temporal(num_input_frames=NUM_IN_FRAMES)

	# Load saved weights
	state_spatial_dict = torch.load(args['model_spatial_file'])
	state_temp_dict = torch.load(args['model_temp_file'])
	if args['cuda']:
		device_ids = [0]
		model_spa = nn.DataParallel(model_spa, device_ids=device_ids).cuda()
		model_temp = nn.DataParallel(model_temp, device_ids=device_ids).cuda()
	else:
		# CPU mode: remove the DataParallel wrapper
		state_spatial_dict = remove_dataparallel_wrapper(state_spatial_dict)
		state_temp_dict = remove_dataparallel_wrapper(state_temp_dict)
	model_spa.load_state_dict(state_spatial_dict)
	model_temp.load_state_dict(state_temp_dict)

	# Sets the model in evaluation mode (e.g. it removes BN)
	model_spa.eval()
	model_temp.eval()

	with torch.no_grad():
		# process data
		seq, _, _ = open_sequence(args['test_path'],\
									False,\
									expand_if_needed=False,\
									max_num_fr=args['max_num_fr_per_seq'])
		# Insert a singleton batch axis -- presumably yielding
		# (num_frames, 1, C, H, W); confirm against open_sequence()
		seq = torch.from_numpy(seq[:, np.newaxis, :, :, :]).to(device)

		seqload_time = time.time()

		# Add noise
		noise = torch.empty_like(seq).normal_(mean=0, std=args['noise_sigma']).to(device)
		seqn = seq + noise
		noisestd = torch.FloatTensor([args['noise_sigma']]).to(device)

		denframes = denoise_seq_dvdnet(seq=seqn,\
										noise_std=noisestd,\
										temp_psz=NUM_IN_FRAMES,\
										model_temporal=model_temp,\
										model_spatial=model_spa,\
										mc_algo=MC_ALGO)
		den_time = time.time()

	# Compute PSNR and log it
	psnr = batch_psnr(denframes, seq.squeeze(), 1.)
	psnr_noisy = batch_psnr(seqn.squeeze(), seq.squeeze(), 1.)
	print("\tPSNR on {} : {}\n".format(os.path.split(args['test_path'])[-1], psnr))
	print("\tDenoising time: {:.2f}s".format(den_time - seqload_time))
	print("\tSequence loaded in : {:.2f}s".format(seqload_time - start_time))
	print("\tTotal time: {:.2f}s\n".format(den_time - start_time))
	logger.info("%s, %s, PSNR noisy %fdB, PSNR %f dB" % \
			 (args['test_path'], args['suffix'], psnr_noisy, psnr))

	# Save outputs
	if not args['dont_save_results']:
		# Save sequence
		save_out_seq(seqn, denframes, args['save_path'], int(args['noise_sigma']*255), \
					   args['suffix'], args['save_noisy'])

	# close logger
	close_logger(logger)
Example #5
0
def test_ffdnet(**args):
    r"""Denoises an input image with FFDNet.

    args (dict) fields used:
        "input": path to the image to denoise
        "cuda": if True, run the model on GPU
        "add_noise": if True, add Gaussian noise of std "noise_sigma" first
        "noise_sigma": noise level (image in [0, 1] range)
        "dont_save_results": if True, skip writing output images (and SSIM)
    """
    # Init logger
    logger = init_logger_ipol()

    # Check if input exists and if it is RGB
    try:
        rgb_den = is_rgb(args['input'])
    except:
        raise Exception('Could not open the input image')

    # Open image as a CxHxW torch.Tensor
    if rgb_den:
        in_ch = 3
        model_fn = 'models/net_rgb.pth'
        imorig = cv2.imread(args['input'])
        # from HxWxC to CxHxW, RGB image
        imorig = (cv2.cvtColor(imorig, cv2.COLOR_BGR2RGB)).transpose(2, 0, 1)
    else:
        # from HxWxC to  CxHxW grayscale image (C=1)
        in_ch = 1
        model_fn = 'logs/net.pth'
        imorig = cv2.imread(args['input'], cv2.IMREAD_GRAYSCALE)
        # Keep an 8-bit copy of the original for the SSIM computation below
        imorig_copy = imorig.copy()
        imorig = np.expand_dims(imorig, 0)
    imorig = np.expand_dims(imorig, 0)

    # Handle odd sizes: the network needs even H and W, so replicate the
    # last row/column; the padding is undone after inference.
    expanded_h = False
    expanded_w = False
    sh_im = imorig.shape
    if sh_im[2] % 2 == 1:
        expanded_h = True
        imorig = np.concatenate((imorig, \
          imorig[:, :, -1, :][:, :, np.newaxis, :]), axis=2)

    if sh_im[3] % 2 == 1:
        expanded_w = True
        imorig = np.concatenate((imorig, \
          imorig[:, :, :, -1][:, :, :, np.newaxis]), axis=3)

    imorig = normalize(imorig)
    imorig = torch.Tensor(imorig)

    # Absolute path to model file
    model_fn = os.path.join(os.path.abspath(os.path.dirname(__file__)), \
       model_fn)

    # Create model
    print('Loading model ...\n')
    net = FFDNet(num_input_channels=in_ch)

    # Load saved weights
    if args['cuda']:
        state_dict = torch.load(model_fn)
        device_ids = [0]
        model = nn.DataParallel(net, device_ids=device_ids).cuda()
    else:
        state_dict = torch.load(model_fn, map_location='cpu')
        # CPU mode: remove the DataParallel wrapper
        state_dict = remove_dataparallel_wrapper(state_dict)
        model = net
    model.load_state_dict(state_dict)

    # Sets the model in evaluation mode (e.g. it removes BN)
    model.eval()

    # Sets data type according to CPU or GPU modes
    if args['cuda']:
        dtype = torch.cuda.FloatTensor
    else:
        dtype = torch.FloatTensor

    # Add noise
    if args['add_noise']:
        noise = torch.FloatTensor(imorig.size()).\
          normal_(mean=0, std=args['noise_sigma'])
        imnoisy = imorig + noise
    else:
        imnoisy = imorig.clone()

    # Test mode: the whole forward pass runs under no_grad so no autograd
    # graph is built (previously only the Variable creation was wrapped).
    with torch.no_grad():  # PyTorch v0.4.0
        imorig, imnoisy = Variable(imorig.type(dtype)), \
            Variable(imnoisy.type(dtype))
        nsigma = Variable(torch.FloatTensor([args['noise_sigma']]).type(dtype))

        # Measure runtime
        start_t = time.time()

        # Estimate noise and subtract it to the input image
        im_noise_estim = model(imnoisy, nsigma)
        outim = torch.clamp(imnoisy - im_noise_estim, 0., 1.)
        stop_t = time.time()

    # Undo the even-size padding
    if expanded_h:
        imorig = imorig[:, :, :-1, :]
        outim = outim[:, :, :-1, :]
        imnoisy = imnoisy[:, :, :-1, :]

    if expanded_w:
        imorig = imorig[:, :, :, :-1]
        outim = outim[:, :, :, :-1]
        imnoisy = imnoisy[:, :, :, :-1]

    # Compute PSNR and log it
    if rgb_den:
        print("### RGB denoising ###")
    else:
        print("### Grayscale denoising ###")
    if args['add_noise']:
        psnr = batch_psnr(outim, imorig, 1.)
        psnr_noisy = batch_psnr(imnoisy, imorig, 1.)

        print("----------PSNR noisy    {0:0.2f}dB".format(psnr_noisy))
        print("----------PSNR denoised {0:0.2f}dB".format(psnr))
    else:
        logger.info("\tNo noise was added, cannot compute PSNR")
    print("----------Runtime     {0:0.4f}s".format(stop_t - start_t))

    # Compute difference (amplified, centered at gray for visualization)
    diffout = 2 * (outim - imorig) + .5
    diffnoise = 2 * (imnoisy - imorig) + .5

    # Save images
    if not args['dont_save_results']:
        noisyimg = variable_to_cv2_image(imnoisy)
        outimg = variable_to_cv2_image(outim)
        cv2.imwrite(
            "bfffd/noisy-" + str(int(args['noise_sigma'] * 255)) + '-' +
            args['input'], noisyimg)
        cv2.imwrite(
            "bfffd/ffdnet-" + str(int(args['noise_sigma'] * 255)) + '-' +
            args['input'], outimg)

        if args['add_noise']:
            cv2.imwrite("noisy_diff.png", variable_to_cv2_image(diffnoise))
            cv2.imwrite("ffdnet_diff.png", variable_to_cv2_image(diffout))

        # SSIM against the original 8-bit image. Only valid in grayscale mode
        # (imorig_copy is defined there) and when the 8-bit images above were
        # produced; previously this ran unconditionally and raised NameError
        # for RGB inputs or when dont_save_results was set.
        if not rgb_den:
            (score, diff) = compare_ssim(noisyimg, imorig_copy, full=True)
            (score2, diff) = compare_ssim(outimg, imorig_copy, full=True)
            print("----------Noisy ssim:    {0:0.4f}".format(score))
            print("----------Denoisy ssim: {0:0.4f}".format(score2))
Example #6
0
def test_fastdvdnet(**args):
    """Denoises all sequences present in a given folder. Sequences must be stored as numbered
	image sequences. The different sequences must be stored in subfolders under the "test_path" folder.

	Inputs:
		args (dict) fields:
			"model_file": path to model
			"test_path": path to sequence to denoise
			"suffix": suffix to add to output name
			"max_num_fr_per_seq": max number of frames to load per sequence
			"noise_sigma": noise level used on test set
			"dont_save_results": if True, don't save output images
			"save_noisy": if True, also save the noisy input frames
			"cuda": if True, run model on GPU; otherwise on CPU
			"save_path": where to save outputs as png
			"gray": if True, perform denoising of grayscale images instead of RGB
	"""
    # Start time
    start_time = time.time()

    # If save_path does not exist, create it
    if not os.path.exists(args['save_path']):
        os.makedirs(args['save_path'])
    logger = init_logger_test(args['save_path'])

    # Sets data type according to CPU or GPU modes
    if args['cuda']:
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    # Create models
    print('Loading models ...')
    model_temp = FastDVDnet(num_input_frames=NUM_IN_FR_EXT)

    # Load saved weights
    state_temp_dict = torch.load(args['model_file'], map_location=device)
    if args['cuda']:
        device_ids = [0]
        model_temp = nn.DataParallel(model_temp, device_ids=device_ids).cuda()
    else:
        # CPU mode: remove the DataParallel wrapper
        state_temp_dict = remove_dataparallel_wrapper(state_temp_dict)
    model_temp.load_state_dict(state_temp_dict)

    # Sets the model in evaluation mode (e.g. it removes BN)
    model_temp.eval()

    with torch.no_grad():
        # process data -- seq is presumably (num_frames, C, H, W) in [0, 1];
        # confirm against open_sequence()
        seq, _, _ = open_sequence(args['test_path'],\
               args['gray'],\
               expand_if_needed=False,\
               max_num_fr=args['max_num_fr_per_seq'])
        seq = torch.from_numpy(seq).to(device)
        seq_time = time.time()

        # Add noise
        noise = torch.empty_like(seq).normal_(
            mean=0, std=args['noise_sigma']).to(device)
        seqn = seq + noise
        noisestd = torch.FloatTensor([args['noise_sigma']]).to(device)

        denframes = denoise_seq_fastdvdnet(seq=seqn,\
                noise_std=noisestd,\
                temp_psz=NUM_IN_FR_EXT,\
                model_temporal=model_temp)

    # Compute PSNR and log it
    stop_time = time.time()
    psnr = batch_psnr(denframes, seq, 1.)
    psnr_noisy = batch_psnr(seqn.squeeze(), seq, 1.)
    loadtime = (seq_time - start_time)
    runtime = (stop_time - seq_time)
    seq_length = seq.size()[0]
    logger.info("Finished denoising {}".format(args['test_path']))
    logger.info("\tDenoised {} frames in {:.3f}s, loaded seq in {:.3f}s".\
        format(seq_length, runtime, loadtime))
    logger.info("\tPSNR noisy {:.4f}dB, PSNR result {:.4f}dB".format(
        psnr_noisy, psnr))

    # Save outputs
    if not args['dont_save_results']:
        # Save sequence
        save_out_seq(seqn, denframes, args['save_path'], \
              int(args['noise_sigma']*255), args['suffix'], args['save_noisy'])

    # close logger
    close_logger(logger)
Example #7
0
def train(args: opts.TrainingOptions) -> None:
    """Train AHDRNet on the scenes in args.training_data.

    Resumes from args.checkpoint when that file exists (the epoch number is
    parsed from the checkpoint filename), then trains up to args.max_epoch,
    periodically saving checkpoints and a 'latest.pkl' symlink.

    NOTE(review): checkpoints store only the model weights and the loss
    history -- the optimizer state is neither saved nor restored, so Adam's
    moment estimates reset on every resume. Confirm this is intended.
    """
    # Create the loader:
    training_data = datsetprocess.AHDRDataset(
        scene_directory=args.training_data)
    loader = data.DataLoader(training_data,
                             batch_size=args.batch_size,
                             shuffle=True,
                             num_workers=1)
    # Create the model:
    # NOTE(review): model is moved with .cuda() but criterion uses
    # args.device.value -- verify both end up on the same device.
    model = models.AHDRNet().cuda()
    criterion = nn.L1Loss().to(args.device.value)
    optimizer = optim.Adam(model.parameters(), lr=args.learn_rate)

    loss_list = []
    current_epoch = 0
    # Load pre-train model:
    if os.path.exists(args.checkpoint):
        checkpoint_file = os.path.realpath(args.checkpoint)
        # Checkpoint files are named '<epoch:06d>.pkl'; resume at the next one
        basename, _ = os.path.splitext(os.path.basename(checkpoint_file))
        current_epoch = int(basename) + 1
        state = torch.load(args.checkpoint)
        loss_list = state['loss_list']
        model.load_state_dict(state['model'])

    # Train:
    progress_bar = tqdm.tqdm(
        iterable=range(current_epoch, args.max_epoch),
        desc='STEP ?/? | LOSS ?.?????? | PSNR ?.????',
        total=args.max_epoch,
        initial=current_epoch,
    )

    for epoch in progress_bar:
        losses = []
        for step, sample in enumerate(loader):
            # Three exposure-aligned inputs plus the HDR ground-truth label
            (batch_x1, batch_x2, batch_x3, batch_x4) = (
                sample['input1'],
                sample['input2'],
                sample['input3'],
                sample['label'],
            )
            (batch_x1, batch_x2, batch_x3, batch_x4) = (
                autograd.Variable(batch_x1).cuda(),
                autograd.Variable(batch_x2).cuda(),
                autograd.Variable(batch_x3).cuda(),
                autograd.Variable(batch_x4).cuda(),
            )

            # Forward and compute loss:
            pre = model(batch_x1, batch_x2, batch_x3)
            loss = criterion(pre, batch_x4)
            psnr = utils.batch_psnr(torch.clamp(pre, 0., 1.), batch_x4, 1.0)
            losses.append(loss.item())
            progress_bar.set_description(
                'STEP {}/{} | LOSS {:0.6f} | PSNR {:0.6f}'.format(
                    step + 1, len(loader), losses[-1], psnr))

            # Update the parameters:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        loss_list.append(np.mean(losses))

        # Save the training model.
        if epoch > 0 and epoch % args.checkpoint_interval == 0:
            # Save progress:
            save_file = '{:06d}.pkl'.format(epoch)
            save_path = os.path.join(args.checkpoint_directory, save_file)
            torch.save({
                'model': model.state_dict(),
                'loss_list': loss_list,
            }, save_path)

            # Create a symlink for easy resume:
            latest = os.path.join(args.checkpoint_directory, 'latest.pkl')
            if os.path.exists(latest):
                os.unlink(latest)
            os.symlink(save_file, latest)
Example #8
0
def test(args):
    """Denoise a single image with FFDNet and report timing and PSNR.

    args fields used: test_path (image file), add_noise (bool), noise_sigma,
    cuda (bool), model_path (directory holding net_gray.pth / net_rgb.pth).
    Writes 'ffdnet.png' (denoised) and, when noise was added, 'noisy.png'.
    """
    # Image
    image = cv2.imread(args.test_path)
    if image is None:
        raise Exception(f'File {args.test_path} not found or error')
    is_gray = utils.is_image_gray(image)
    image = read_image(args.test_path, is_gray)
    print("{} image shape: {}".format("Gray" if is_gray else "RGB",
                                      image.shape))

    # Expand odd shape to even (replicate last row/col; undone after inference)
    expend_W = False
    expend_H = False
    if image.shape[1] % 2 != 0:
        expend_W = True
        image = np.concatenate((image, image[:, -1, :][:, np.newaxis, :]),
                               axis=1)
    if image.shape[2] % 2 != 0:
        expend_H = True
        image = np.concatenate((image, image[:, :, -1][:, :, np.newaxis]),
                               axis=2)

    # Noise
    image = torch.FloatTensor([image])  # 1 * C(1 / 3) * W * H
    # Keep a clean copy BEFORE noise is added, so PSNR below compares the
    # prediction against the clean image (previously it was computed against
    # the noisy input, making the reported "PSNR denoised" wrong when
    # add_noise was set).
    image_clean = image.clone()
    if args.add_noise:
        image = utils.add_batch_noise(image, args.noise_sigma)
    noise_sigma = torch.FloatTensor([args.noise_sigma])

    # Model & GPU
    model = FFDNet(is_gray=is_gray)
    if args.cuda:
        image = image.cuda()
        image_clean = image_clean.cuda()
        noise_sigma = noise_sigma.cuda()
        model = model.cuda()

    # Dict
    model_path = args.model_path + ('net_gray.pth'
                                    if is_gray else 'net_rgb.pth')
    print(f"> Loading model param in {model_path}...")
    state_dict = torch.load(model_path)
    model.load_state_dict(state_dict)
    model.eval()
    print('\n')

    # Test
    with torch.no_grad():
        start_time = time.time()
        image_pred = model(image, noise_sigma)
        stop_time = time.time()
        print("Test time: {0:.4f}s".format(stop_time - start_time))

    # PSNR of the prediction against the clean reference
    psnr = utils.batch_psnr(img=image_pred, imclean=image_clean, data_range=1)
    print("PSNR denoised {0:.2f}dB".format(psnr))

    # UnExpand odd
    if expend_W:
        image_pred = image_pred[:, :, :-1, :]
    if expend_H:
        image_pred = image_pred[:, :, :, :-1]

    # Save
    cv2.imwrite("ffdnet.png", utils.variable_to_cv2_image(image_pred))
    if args.add_noise:
        cv2.imwrite("noisy.png", utils.variable_to_cv2_image(image))
Example #9
0
                seq_time = time.time()

                # Add noise
                noise = torch.empty_like(seq).normal_(
                    mean=0, std=args.noise_sigma).to(device)
                seqn = seq + noise
                noisestd = torch.FloatTensor([args.noise_sigma]).to(device)

                denframes = denoise_seq_fastdvdnet(seq=seq,
                                                   noise_std=noisestd,
                                                   temp_psz=NUM_IN_FR_EXT,
                                                   model_temporal=model_temp)

                # Compute PSNR and log it
                stop_time = time.time()
                psnr = batch_psnr(denframes, seq, 1.)
                psnr_noisy = batch_psnr(seqn.squeeze(), seq, 1.)
                loadtime = (seq_time - start_time)
                runtime = (stop_time - seq_time)
                seq_length = seq.size()[0]

                print("\tDenoised {} frames in {:.3f}s, loaded seq in {:.3f}s".
                      format(seq_length, runtime, loadtime))
                print("\tPSNR noisy {:.4f}dB, PSNR result {:.4f}dB".format(
                    psnr_noisy, psnr))

                # Save outputs
                seq_len = len(seq_list)
                for idx in range(seq_len):
                    out_name = os.path.join(args.save_path, seq_outnames[idx])
                    print("Saving %s" % out_name)