Beispiel #1
0
    def __init__(self,
                 config_file,
                 im_size,
                 config_spec=None,
                 cropping="random",
                 cache_dir=None,
                 use_cache=False,
                 dataset_name="default_dataset_name"):
        """Dataset for generating degraded images on the fly.

        Reads the data config and collects the names of the linear target
        patches stored as ``<dataset_dir>/images/target/*.npy``.

        Args:
            config_file: path to the data config file.
            im_size: tuple of (w, h), or None.
            config_spec: path to the data config spec file; when None,
                ``_configspec_path()`` is used.
            cropping: cropping mode, one of ["random", "center"].
            cache_dir: cache directory (only stored here).
            use_cache: cache flag (only stored here).
            dataset_name: prefix of ``self.dataset_name``; the image size
                ("WxH", or "None" when ``im_size`` is None) is appended
                after an underscore.
        """
        super().__init__()
        spec_path = _configspec_path() if config_spec is None else config_spec
        cfg = read_config(config_file, spec_path)
        self.config_file = config_file

        # Root directory of the linear patches.
        root = cfg['dataset_dir']
        # Boolean flags controlling how pipelines are created, and the ranges
        # of the pipeline parameters (see data_configspec for details).
        self.pipeline_configs = cfg['pipeline_configs']
        self.pipeline_param_ranges = cfg['pipeline_param_ranges']

        # Patch names without directory or extension, ordered by the
        # Adler-32 checksum of the name.
        pattern = os.path.join(root, 'images/target/*.npy')
        names = [os.path.splitext(os.path.basename(path))[0]
                 for path in glob.glob(pattern)]
        self.file_list = sorted(
            names, key=lambda name: zlib.adler32(name.encode('utf-8')))

        print('Data Pipeline Configs: ', self.pipeline_configs)
        print('Data Pipeline Param Ranges: ', self.pipeline_param_ranges)
        self.data_root = root
        self.im_size = im_size
        self.cropping = cropping
        self.use_cache = use_cache
        self.cache_dir = cache_dir

        if self.im_size is None:
            size_tag = "None"
        else:
            size_tag = "{}x{}".format(self.im_size[0], self.im_size[1])
        self.dataset_name = "_".join([dataset_name, size_tag])
    def __init__(self,
                 config_file,
                 config_spec=None,
                 img_format='.bmp',
                 degamma=True,
                 color=True,
                 blind=False,
                 train=True):
        """Dataset built from the image files found in ``dataset_dir``.

        Args:
            config_file: path to the config file; its ``dataset_configs``
                section provides the dataset directory, burst length, patch
                size, down-sampling factor and jitter settings.
            config_spec: path to the config spec; when None,
                ``self._configspec_path()`` is used.
            img_format: file extension selecting the images in
                ``dataset_dir`` (e.g. '.bmp').
            degamma: whether to undo the gamma curve (stored as a flag).
            color: whether to process color images (stored as a flag).
            blind: whether to use blind estimation, i.e. the estimated noise
                variance is not fed to the network (stored as a flag).
            train: training-mode flag (stored as a flag).
        """
        super(TrainDataSet, self).__init__()
        if config_spec is None:
            config_spec = self._configspec_path()
        config = read_config(config_file, config_spec)
        self.dataset_config = config['dataset_configs']
        self.dataset_dir = self.dataset_config['dataset_dir']
        # Match on the extension (a substring test would also accept names
        # such as 'x.bmp.bak') and sort so that the sample order does not
        # depend on the arbitrary order returned by os.listdir.
        self.images = sorted(name for name in os.listdir(self.dataset_dir)
                             if name.endswith(img_format))
        self.burst_size = self.dataset_config['burst_length']
        self.patch_size = self.dataset_config['patch_size']

        self.upscale = self.dataset_config['down_sample']
        self.big_jitter = self.dataset_config['big_jitter']
        self.small_jitter = self.dataset_config['small_jitter']
        # maximum offset at the scale before down-sampling
        self.jitter_upscale = self.big_jitter * self.upscale
        # patch size before down-sampling, including the jitter margin on
        # both sides
        self.size_upscale = self.patch_size * self.upscale + 2 * self.jitter_upscale
        # difference between the big and the small jitter, at the scale
        # before down-sampling
        self.delta_upscale = (self.big_jitter -
                              self.small_jitter) * self.upscale
        # patch size mapped back to the original image
        self.patch_size_upscale = self.patch_size * self.upscale
        # undo the gamma curve
        self.degamma = degamma
        # whether to process color images
        self.color = color
        # blind estimation: the estimated noise variance is not used as a
        # network input
        self.blind = blind
        self.train = train

        self.vertical_flip = Random_Vertical_Flip(p=0.5)
        self.horizontal_flip = Random_Horizontal_Flip(p=0.5)
Beispiel #3
0
def eval(config, args, max_images=100):
    """Evaluate a trained KPN checkpoint and save the resulting images.

    The name shadows the builtin ``eval`` within this module; it is kept for
    compatibility with existing callers.

    Args:
        config: parsed config with ``training`` and ``architecture`` sections.
        args: parsed command-line arguments; ``num_workers``, ``cuda``,
            ``mGPU`` and ``checkpoint`` are read.
        max_images: maximum number of samples to evaluate (defaults to 100,
            the previously hard-coded limit).

    For every evaluated sample, the gamma-corrected noisy first frame, the
    prediction and the ground truth are written as PNG files to
    ``eval_dir`` (plain files already in it are deleted first). Per-image
    PSNR/SSIM and their averages are printed. Returns early, without touching
    ``eval_dir``, when ``checkpoint_dir`` is missing or empty.
    """
    train_config = config['training']
    arch_config = config['architecture']

    print('Eval Process......')

    checkpoint_dir = train_config['checkpoint_dir']
    if not os.path.exists(checkpoint_dir) or not os.listdir(checkpoint_dir):
        print('There is no any checkpoint file in path:{}'.format(
            checkpoint_dir))
        # Nothing to load: stop before clearing eval_dir and building the
        # model, instead of failing later in load_checkpoint.
        return

    # the path for saving eval images; results of a previous run are removed
    eval_dir = train_config['eval_dir']
    os.makedirs(eval_dir, exist_ok=True)
    for name in os.listdir(eval_dir):
        path = os.path.join(eval_dir, name)
        # os.remove cannot delete directories, so only plain files are removed.
        if os.path.isfile(path):
            os.remove(path)

    # dataset and dataloader (one sample per batch, no shuffling)
    data_set = TrainDataSet(train_config['dataset_configs'],
                            img_format='.bmp',
                            degamma=True,
                            color=False,
                            blind=arch_config['blind_est'],
                            train=False)
    data_loader = DataLoader(data_set,
                             batch_size=1,
                             shuffle=False,
                             num_workers=args.num_workers)

    dataset_config = read_config(train_config['dataset_configs'],
                                 _configspec_path())['dataset_configs']
    burst_length = dataset_config['burst_length']

    # model here
    model = KPN(color=False,
                burst_length=burst_length,
                blind_est=arch_config['blind_est'],
                kernel_size=list(map(int, arch_config['kernel_size'].split())),
                sep_conv=arch_config['sep_conv'],
                channel_att=arch_config['channel_att'],
                spatial_att=arch_config['spatial_att'],
                upMode=arch_config['upMode'],
                core_bias=arch_config['core_bias'])
    if args.cuda:
        model = model.cuda()

    if args.mGPU:
        model = nn.DataParallel(model)
    # load trained model
    ckpt = load_checkpoint(checkpoint_dir, args.checkpoint)
    model.load_state_dict(ckpt['state_dict'])
    print('The model has been loaded from epoch {}, n_iter {}.'.format(
        ckpt['epoch'], ckpt['global_iter']))
    # switch the eval mode
    model.eval()

    trans = transforms.ToPILImage()

    psnr = 0.0
    ssim = 0.0
    n_evaluated = 0
    with torch.no_grad():
        for i, (burst_noise, gt, white_level) in enumerate(data_loader):
            if i >= max_images:
                break
            if args.cuda:
                burst_noise = burst_noise.cuda()
                gt = gt.cuda()
                white_level = white_level.cuda()

            pred_i, pred = model(burst_noise,
                                 burst_noise[:, 0:burst_length, ...],
                                 white_level)

            pred_i = sRGBGamma(pred_i)
            pred = sRGBGamma(pred)
            gt = sRGBGamma(gt)
            burst_noise = sRGBGamma(burst_noise / white_level)

            psnr_t = calculate_psnr(pred.unsqueeze(1), gt.unsqueeze(1))
            ssim_t = calculate_ssim(pred.unsqueeze(1), gt.unsqueeze(1))
            psnr_noisy = calculate_psnr(
                burst_noise[:, 0, ...].unsqueeze(1), gt.unsqueeze(1))
            psnr += psnr_t
            ssim += ssim_t

            pred = torch.clamp(pred, 0.0, 1.0)

            if args.cuda:
                pred = pred.cpu()
                gt = gt.cpu()
                burst_noise = burst_noise.cpu()

            trans(burst_noise[0, 0, ...].squeeze()).save(os.path.join(
                eval_dir, '{}_noisy_{:.2f}dB.png'.format(i, psnr_noisy)),
                                                         quality=100)
            trans(pred.squeeze()).save(os.path.join(
                eval_dir, '{}_pred_{:.2f}dB.png'.format(i, psnr_t)),
                                       quality=100)
            trans(gt.squeeze()).save(os.path.join(eval_dir,
                                                  '{}_gt.png'.format(i)),
                                     quality=100)

            n_evaluated += 1
            print('{}-th image is OK, with PSNR: {:.2f}dB, SSIM: {:.4f}'.
                  format(i, psnr_t, ssim_t))

    # Average over the images actually evaluated: the dataset may hold fewer
    # than max_images samples. Guard against an empty loader.
    denom = max(n_evaluated, 1)
    print('All images are OK, average PSNR: {:.2f}dB, SSIM: {:.4f}'.format(
        psnr / denom, ssim / denom))
Beispiel #4
0
def train(config, num_workers, num_threads, cuda, restart_train, mGPU):
    """Train the KPN model and periodically save checkpoints.

    Args:
        config: parsed config with ``training`` and ``architecture`` sections.
        num_workers: number of DataLoader worker processes.
        num_threads: currently unused (``torch.set_num_threads`` is disabled).
        cuda: move the model and every batch to the GPU when True.
        restart_train: train from scratch instead of resuming from the
            ``best`` checkpoint in ``checkpoint_dir``.
        mGPU: wrap the model in ``nn.DataParallel`` when True.

    The TensorBoard log directory is wiped at the start of every run.
    """
    # torch.set_num_threads(num_threads)

    train_config = config['training']
    arch_config = config['architecture']

    batch_size = train_config['batch_size']
    lr = train_config['learning_rate']
    weight_decay = train_config['weight_decay']
    # NOTE(review): decay_steps is read but not used -- the scheduler below
    # decays every 10 epochs. Confirm which one is intended.
    decay_step = train_config['decay_steps']
    lr_decay = train_config['lr_decay']
    n_epoch = train_config['num_epochs']
    # lower bound for the learning rate during decay
    min_lr = 5e-6

    print('Configs:', config)
    # checkpoint path
    checkpoint_dir = train_config['checkpoint_dir']
    os.makedirs(checkpoint_dir, exist_ok=True)
    # logs path: every run starts from an empty TensorBoard directory
    logs_dir = train_config['logs_dir']
    if os.path.exists(logs_dir):
        shutil.rmtree(logs_dir)
    log_writer = SummaryWriter(logs_dir)

    # dataset and dataloader
    data_set = TrainDataSet(train_config['dataset_configs'],
                            img_format='.bmp',
                            degamma=True,
                            color=False,
                            blind=arch_config['blind_est'])
    data_loader = DataLoader(data_set,
                             batch_size=batch_size,
                             shuffle=True,
                             num_workers=num_workers)
    dataset_config = read_config(train_config['dataset_configs'],
                                 _configspec_path())['dataset_configs']
    burst_length = dataset_config['burst_length']

    # model here
    model = KPN(color=False,
                burst_length=burst_length,
                blind_est=arch_config['blind_est'],
                kernel_size=list(map(int, arch_config['kernel_size'].split())),
                sep_conv=arch_config['sep_conv'],
                channel_att=arch_config['channel_att'],
                spatial_att=arch_config['spatial_att'],
                upMode=arch_config['upMode'],
                core_bias=arch_config['core_bias'])
    if cuda:
        model = model.cuda()

    if mGPU:
        model = nn.DataParallel(model)
    model.train()

    # loss function here
    loss_func = LossFunc(coeff_basic=1.0,
                         coeff_anneal=1.0,
                         gradient_L1=True,
                         alpha=arch_config['alpha'],
                         beta=arch_config['beta'])

    # Optimizer here
    if train_config['optimizer'] == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=lr)
    elif train_config['optimizer'] == 'sgd':
        optimizer = optim.SGD(model.parameters(),
                              lr=lr,
                              momentum=0.9,
                              weight_decay=weight_decay)
    else:
        raise ValueError(
            "Optimizer must be 'sgd' or 'adam', but received {}.".format(
                train_config['optimizer']))

    # learning rate scheduler: multiplies the lr by lr_decay every 10 calls
    # to scheduler.step() (called once per epoch below)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=10, gamma=lr_decay)

    average_loss = MovingAverage(train_config['save_freq'])
    start_epoch = 0
    global_step = 0
    best_loss = np.inf
    if not restart_train:
        try:
            checkpoint = load_checkpoint(checkpoint_dir, 'best')
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            scheduler.load_state_dict(checkpoint['lr_scheduler'])
            start_epoch = checkpoint['epoch']
            global_step = checkpoint['global_iter']
            best_loss = checkpoint['best_loss']
            print('=> loaded checkpoint (epoch {}, global_step {})'.format(
                start_epoch, global_step))
        except Exception as err:
            # Best effort: fall back to training from scratch, but report
            # why. NOTE: if a later load step failed, the model may keep
            # weights that were already loaded.
            start_epoch = 0
            global_step = 0
            best_loss = np.inf
            print('=> no checkpoint file to be loaded ({!r}).'.format(err))
    else:
        # checkpoint_dir was created above; existing files are kept.
        print('=> training')

    for epoch in range(start_epoch, n_epoch):
        epoch_start_time = time.time()
        # Decay the learning rate once per epoch while it is above the floor,
        # then clamp so that a decay step can never push it below min_lr.
        if optimizer.param_groups[0]['lr'] > min_lr:
            scheduler.step()
        for param in optimizer.param_groups:
            param['lr'] = max(param['lr'], min_lr)
        print(
            '=' * 20,
            'lr={}'.format([param['lr'] for param in optimizer.param_groups]),
            '=' * 20)
        t1 = time.time()
        for step, (burst_noise, gt, white_level) in enumerate(data_loader):
            if cuda:
                burst_noise = burst_noise.cuda()
                gt = gt.cuda()
                # white_level is a model input too; keep it on the same
                # device as the other inputs.
                white_level = white_level.cuda()

            pred_i, pred = model(burst_noise,
                                 burst_noise[:, 0:burst_length, ...],
                                 white_level)

            loss_basic, loss_anneal = loss_func(sRGBGamma(pred_i),
                                                sRGBGamma(pred), sRGBGamma(gt),
                                                global_step)
            loss = loss_basic + loss_anneal
            # backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Track a plain float, so the moving average (and the best_loss
            # stored in checkpoints) does not hold autograd graphs or GPU
            # tensors.
            average_loss.update(loss.item())
            # calculate PSNR
            psnr = calculate_psnr(pred.unsqueeze(1), gt.unsqueeze(1))
            ssim = calculate_ssim(pred.unsqueeze(1), gt.unsqueeze(1))

            # add scalars to tensorboardX
            log_writer.add_scalar('loss_basic', loss_basic, global_step)
            log_writer.add_scalar('loss_anneal', loss_anneal, global_step)
            log_writer.add_scalar('loss_total', loss, global_step)
            log_writer.add_scalar('psnr', psnr, global_step)
            log_writer.add_scalar('ssim', ssim, global_step)

            print(
                '{:-4d}\t| epoch {:2d}\t| step {:4d}\t| loss_basic: {:.4f}\t| loss_anneal: {:.4f}\t|'
                ' loss: {:.4f}\t| PSNR: {:.2f}dB\t| SSIM: {:.4f}\t| time:{:.2f} seconds.'
                .format(global_step, epoch, step, loss_basic, loss_anneal,
                        loss, psnr, ssim,
                        time.time() - t1))
            t1 = time.time()
            global_step += 1

            if global_step % train_config['save_freq'] == 0:
                current = average_loss.get_value()
                is_best = current < best_loss
                if is_best:
                    best_loss = current

                save_dict = {
                    'epoch': epoch,
                    'global_iter': global_step,
                    'state_dict': model.state_dict(),
                    'best_loss': best_loss,
                    'optimizer': optimizer.state_dict(),
                    'lr_scheduler': scheduler.state_dict()
                }
                save_checkpoint(save_dict,
                                is_best,
                                checkpoint_dir,
                                global_step,
                                max_keep=train_config['ckpt_to_keep'])

        print('Epoch {} is finished, time elapsed {:.2f} seconds.'.format(
            epoch,
            time.time() - epoch_start_time))

    # flush pending TensorBoard events
    log_writer.close()
Beispiel #5
0
    # Remaining command-line options (the parser and its earlier options are
    # defined before this point): threading, device selection, and the
    # evaluation switch with the checkpoint to evaluate.
    parser.add_argument('--num_threads', '-nt', type=int, default=8,
                        help='number of threads in data loader')
    parser.add_argument('--cuda', '-c', action='store_true',
                        help='whether to train on the GPU')
    parser.add_argument('--mGPU', '-m', action='store_true',
                        help='whether to train on multiple GPUs')
    parser.add_argument('--eval', action='store_true',
                        help='whether to work on the evaluation mode')
    parser.add_argument('--checkpoint', '-ckpt', dest='checkpoint',
                        type=str, default='best',
                        help='the checkpoint to eval')
    args = parser.parse_args()

    # Load the config, then either train or evaluate.
    config = read_config(args.config_file, args.config_spec)
    if not args.eval:
        train(config, args.num_workers, args.num_threads, args.cuda,
              args.restart, args.mGPU)
    else:
        eval(config, args)
Beispiel #6
0
    def __init__(self,
                 config_file,
                 config_spec=None,
                 blind=False,
                 cropping="random",
                 cache_dir=None,
                 use_cache=False,
                 dataset_name="synthetic"):
        """Dataset for generating degraded images on the fly.

        Collects the ``*.pth`` files in ``dataset_dir``. It also precomputes
        the crop sizes and offsets implied by the patch size, the big/small
        jitter and the down-sampling factor.

        Args:
            config_file: path to the data config file.
            config_spec: path to the data config spec file; when None,
                ``_configspec_path()`` is used.
            blind: blind-estimation flag (only stored here).
            cropping: cropping mode, one of ["random", "center"].
            cache_dir: cache directory (only stored here).
            use_cache: cache flag (only stored here).
            dataset_name: prefix of ``self.dataset_name``; "<size>x<size>"
                (the patch size) is appended after an underscore.
        """
        super().__init__()
        spec_path = config_spec if config_spec is not None else _configspec_path()
        cfg = read_config(config_file, spec_path)

        # dataset settings and the directory holding the linear patches
        self.dataset_configs = cfg['dataset_configs']
        root = self.dataset_configs['dataset_dir']
        # pipeline flags and parameter ranges (see data_configspec)
        self.pipeline_configs = cfg['pipeline_configs']
        self.pipeline_param_ranges = cfg['pipeline_param_ranges']

        # patch names without directory or extension, ordered by the Adler-32
        # checksum of the name
        stems = [os.path.splitext(os.path.basename(path))[0]
                 for path in glob.glob(os.path.join(root, '*.pth'))]
        self.file_list = sorted(
            stems, key=lambda stem: zlib.adler32(stem.encode('utf-8')))
        self.data_root = root

        ds_cfg = self.dataset_configs
        # patch size after down-sampling
        self.im_size = ds_cfg['patch_size']
        self.big_jitter = ds_cfg['big_jitter']
        self.small_jitter = ds_cfg['small_jitter']
        self.down_sample = ds_cfg['down_sample']

        # extra pixels (at the down-sampled scale) that leave room to choose
        # the Bayer pattern at random
        bayer_margin = 2
        factor = self.down_sample
        # crop size at the original scale, including the big-jitter margin
        self.im_size_upscale = (self.im_size + 2 * self.big_jitter + bayer_margin) * factor
        # offset from the big-jitter crop back to the real image
        self.big_restore_upscale = self.big_jitter * factor
        # small-jitter shift at the original scale
        self.small_restore_upscale = self.small_jitter * factor
        # offset from the big-jitter crop to the small-jitter crop
        self.big2small_upscale = (self.big_jitter - self.small_jitter) * factor
        # patch size plus the Bayer margin, at the original scale
        self.im_size_extra = (self.im_size + bayer_margin) * factor

        self.blind = blind
        self.cropping = cropping
        self.use_cache = use_cache
        self.cache_dir = cache_dir

        size_tag = ("{}x{}".format(self.im_size, self.im_size)
                    if self.im_size is not None else "None")
        self.dataset_name = "_".join([dataset_name, size_tag])

        self.burst_length = ds_cfg['burst_length']
Beispiel #7
0
def train():
    """Train ``Network`` to map the first frame of a burst to its last frame.

    Command line: ``--restart`` / ``-r`` trains from scratch; otherwise the
    ``best`` checkpoint in ``./noise_models`` is loaded first. The model is
    trained with an L1 loss for 100 epochs. Every 200 iterations a checkpoint
    is saved; it is flagged as best when the 200-iteration moving-average
    loss reaches a new minimum. Loss curves go to ``./logs``.
    """
    import os

    # Parse arguments first, so that e.g. --help does not create a log dir.
    parser = argparse.ArgumentParser()
    parser.add_argument('--restart', '-r', action='store_true')
    args = parser.parse_args()
    log_writer = SummaryWriter('./logs')

    config = read_config('kpn_specs/att_kpn_config.conf', 'kpn_specs/configspec.conf')
    train_config = config['training']
    data_set = TrainDataSet(
        train_config['dataset_configs'],
        img_format='.bmp',
        degamma=True,
        color=True,
        blind=False
    )
    data_loader = DataLoader(
        dataset=data_set,
        batch_size=32,
        shuffle=True,
        num_workers=4
    )

    loss_fn = nn.L1Loss()

    model = Network(True).cuda()
    model.train()

    optimizer = optim.Adam(model.parameters(), lr=5e-5)

    checkpoint_dir = './noise_models'
    if not args.restart:
        model.load_state_dict(load_checkpoint(checkpoint_dir, best_or_latest='best'))
    global_iter = 0
    min_loss = np.inf
    loss_ave = MovingAverage(200)

    os.makedirs(checkpoint_dir, exist_ok=True)

    try:
        for epoch in range(100):
            # the two extra items of each batch are not needed here
            for step, (data, _, _) in enumerate(data_loader):
                # first frame of the burst is the input, last frame the target
                feed = data[:, 0, ...].cuda()
                gt = data[:, -1, ...].cuda()
                pred = model(feed)

                loss = loss_fn(pred, gt)

                global_iter += 1

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # plain float: keeps autograd graphs and GPU tensors out of
                # the moving average
                loss_value = loss.item()
                log_writer.add_scalar('loss', loss_value, global_iter)

                loss_ave.update(loss_value)
                if global_iter % 200 == 0:
                    loss_t = loss_ave.get_value()
                    min_loss = min(min_loss, loss_t)
                    is_best = min_loss == loss_t
                    save_checkpoint(
                        model.state_dict(),
                        is_best=is_best,
                        checkpoint_dir=checkpoint_dir,
                        n_iter=global_iter
                    )
                print('{: 6d}, epoch {: 3d}, iter {: 4d}, loss {:.4f}'.format(
                    global_iter, epoch, step, loss_value))
    finally:
        # flush and release the TensorBoard writer even if training aborts
        log_writer.close()